Initial Exploration of the P/NBD Model
In this workbook we introduce the various BTYD models, starting with a discussion of the underlying theory.
1 Background Theory
Before we start working on fitting and using the various Buy-Till-You-Die models, we first need to discuss the basic underlying theory and model.
In this model, we assume a customer becomes ‘alive’ to the business at the first purchase, makes purchases stochastically but at a steady rate for a period of time, and then ‘dies’ - i.e. becomes inactive to the business - hence the name “Buy-Till-You-Die”.
Thus, at a high level these models decompose into modelling the transaction events using distributions such as the Poisson or Negative Binomial, and then modelling the ‘dropout’ process using some other method.
A number of BTYD models exist; for this workshop we will focus on the BG/NBD model - the Beta-Geometric/Negative Binomial Distribution model - though we will also discuss the P/NBD model.
These models require only two pieces of information about each customer’s purchasing history: the “recency” (when the last transaction occurred) and “frequency” (the count of transactions made by that customer in a specified time period).
The notation used to represent this information is
\[ X = (x, \, t_x, \, T), \] where
\[ \begin{eqnarray*} x &=& \text{the number of transactions}, \\ t_x &=& \text{the time of the last transaction}, \\ T &=& \text{the observed time period}. \end{eqnarray*} \]
From this summary data we can fit most BTYD models.
2 BTYD Models
There are a number of different statistical approaches to building BTYD models, each relying on different assumptions about how the recency, frequency and monetary values are modelled.
We now discuss a few of these approaches.
2.1 Pareto/Negative-Binomial Distribution (P/NBD) Model
The P/NBD model relies on five assumptions:
- While active, the number of transactions made by a customer follows a Poisson process with transaction rate \(\lambda\).
- Heterogeneity in \(\lambda\) follows a Gamma distribution \(\Gamma(\lambda \, | \, r, \alpha)\) with shape \(r\) and rate \(\alpha\).
- Each customer has an unobserved ‘lifetime’ of length \(\tau\). The point at which the customer becomes inactive is exponentially distributed with dropout rate \(\mu\).
- Heterogeneity in dropout rates across customers follows a Gamma distribution \(\Gamma(\mu \, | \, s, \beta)\) with shape parameter \(s\) and rate parameter \(\beta\).
- The transaction rate \(\lambda\) and the dropout rate \(\mu\) vary independently across customers.
As before, we express this in mathematical notation as:
\[ \begin{eqnarray*} \lambda &\sim& \Gamma(r, \alpha), \\ \mu &\sim& \Gamma(s, \beta), \\ \tau &\sim& \text{Exponential}(\mu) \end{eqnarray*} \]
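To build intuition for these assumptions, we can sketch the generative process in R. This is a minimal simulation under hypothetical hyperparameter values - the values of `r`, `alpha`, `s`, `beta` and the observation window below are made up for illustration and are not those used later in the workbook:

```r
set.seed(42)

# Hypothetical hyperparameters, for illustration only
r <- 1; alpha <- 4    # lambda ~ Gamma(shape = r, rate = alpha)
s <- 1; beta  <- 10   # mu     ~ Gamma(shape = s, rate = beta)
T_obs <- 52           # observation window

simulate_pnbd_customer <- function(r, alpha, s, beta, T_obs) {
  lambda <- rgamma(1, shape = r, rate = alpha)  # customer purchase rate
  mu     <- rgamma(1, shape = s, rate = beta)   # customer dropout rate
  tau    <- rexp(1, rate = mu)                  # unobserved lifetime

  # Poisson process: exponential gaps until the customer dies or is censored
  active_until <- min(tau, T_obs)
  times <- c()
  t <- rexp(1, rate = lambda)
  while (t <= active_until) {
    times <- c(times, t)
    t <- t + rexp(1, rate = lambda)
  }

  list(
    x     = length(times),                             # frequency
    t_x   = if (length(times) > 0) max(times) else 0,  # recency
    T_cal = T_obs
  )
}

cust <- simulate_pnbd_customer(r, alpha, s, beta, T_obs)
```

Each simulated customer reduces to exactly the \((x, t_x, T)\) summary described above, which is all the model will see.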
2.2 Beta-Geometric/Negative-Binomial Distribution (BG/NBD) Model
This model relies on a number of base assumptions, somewhat similar to the P/NBD model but modelling lifetime with a different process:
- While active, the number of transactions made by a customer follows a Poisson process with transaction rate \(\lambda\).
- Heterogeneity in \(\lambda\) follows a Gamma distribution \(\Gamma(\lambda \, | \, r, \alpha)\) with shape \(r\) and rate \(\alpha\).
- After any transaction, a customer becomes inactive with probability \(p\).
- Heterogeneity in \(p\) follows a Beta distribution \(B(p \, | \, a, b)\) with shape parameters \(a\) and \(b\).
- The transaction rate \(\lambda\) and the dropout probability \(p\) vary independently across customers.
Note that it follows from the above assumptions that the number of transactions a customer makes before becoming inactive follows a Geometric distribution - hence the ‘Beta-Geometric’ in the name.
To put this into more formal mathematical notation, we have:
\[ \begin{eqnarray*} \lambda &\sim& \Gamma(r, \alpha), \\ P(\text{inactive after transaction } k) &=& p \, (1 - p)^{k-1}, \\ p &\sim& \text{Beta}(a, b) \end{eqnarray*} \]
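A short R sketch of the dropout side of this process - the `a` and `b` values are made up for illustration:

```r
set.seed(123)

# Hypothetical Beta parameters, for illustration only
a <- 1; b <- 4

p <- rbeta(1, shape1 = a, shape2 = b)  # this customer's dropout probability

# The transaction count before dropout is Geometric:
# P(inactive after the k-th transaction) = p * (1 - p)^(k - 1).
# rgeom() counts failures before the first success, so shift by one.
k_dropout <- rgeom(1, prob = p) + 1

# The same probability mass function written out directly
geom_pmf <- function(k, p) p * (1 - p)^(k - 1)
```

Summing `geom_pmf()` over all \(k \geq 1\) gives 1, confirming it is a valid distribution over the transaction count at dropout.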
3 Initial P/NBD Models
We start by fitting the P/NBD model to our synthetic datasets before we try to model real-life data.
3.1 Load Long Time-frame Synthetic Data
customer_cohortdata_tbl <- read_rds("data/synthdata_longframe_cohort_tbl.rds")
customer_cohortdata_tbl %>% glimpse()
## Rows: 50,000
## Columns: 4
## $ customer_id    <chr> "C201101_0001", "C201101_0002", "C201101_0003", "C20110…
## $ cohort_qtr     <chr> "2011 Q1", "2011 Q1", "2011 Q1", "2011 Q1", "2011 Q1", …
## $ cohort_ym      <chr> "2011 01", "2011 01", "2011 01", "2011 01", "2011 01", …
## $ first_tnx_date <date> 2011-01-01, 2011-01-01, 2011-01-01, 2011-01-01, 2011-0…
customer_simparams_tbl <- read_rds("data/synthdata_longframe_simparams_tbl.rds")
customer_simparams_tbl %>% glimpse()
## Rows: 50,000
## Columns: 9
## $ customer_id     <chr> "C201101_0001", "C201101_0002", "C201101_0003", "C2011…
## $ cohort_qtr      <chr> "2011 Q1", "2011 Q1", "2011 Q1", "2011 Q1", "2011 Q1",…
## $ cohort_ym       <chr> "2011 01", "2011 01", "2011 01", "2011 01", "2011 01",…
## $ first_tnx_date  <date> 2011-01-01, 2011-01-01, 2011-01-01, 2011-01-01, 2011-…
## $ customer_mu     <dbl> 0.04783829, 0.10238990, 0.03607158, 0.10886867, 0.0871…
## $ customer_tau    <dbl> 16.59526247, 3.98578810, 1.36755838, 16.12324067, 1.61…
## $ customer_lambda <dbl> 0.37963949, 0.08399882, 0.27328828, 0.32830425, 0.0561…
## $ customer_nu     <dbl> 1.704865659, 0.162742489, 0.874737616, 1.005165071, 0.…
## $ customer_p      <dbl> 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100,…
customer_transactions_tbl <- read_rds("data/synthdata_longframe_transactions_tbl.rds")
customer_transactions_tbl %>% glimpse()
## Rows: 269,702
## Columns: 4
## $ customer_id   <chr> "C201101_0002", "C201101_0018", "C201101_0001", "C201101…
## $ tnx_timestamp <dttm> 2011-01-01 00:20:28, 2011-01-01 00:43:10, 2011-01-01 02…
## $ invoice_id    <chr> "T20110101-0001", "T20110101-0002", "T20110101-0003", "T…
## $ tnx_amount    <dbl> 547.93, 90.67, 57.90, 143.08, 113.37, 22.20, 574.21, 112…
We reproduce the visualisation of the transaction times used in previous workbooks.
plot_tbl <- customer_transactions_tbl %>%
group_nest(customer_id, .key = "cust_data") %>%
filter(map_int(cust_data, nrow) > 3) %>%
slice_sample(n = 30) %>%
unnest(cust_data)
ggplot(plot_tbl, aes(x = tnx_timestamp, y = customer_id)) +
geom_line() +
geom_point() +
labs(
x = "Date",
y = "Customer ID",
title = "Visualisation of Customer Transaction Times"
) +
theme(axis.text.y = element_text(size = 10))
3.2 Derive the Log-likelihood Function
We now turn our attention to deriving the log-likelihood model for the P/NBD model.
We assume that we know that a given customer has made \(x\) transactions after the initial one over an observed period of time \(T\), and we label these transactions \(t_1\), \(t_2\), …, \(t_x\).
To model the likelihood for this observation, we need to consider two possibilities: one for where the customer is still ‘alive’ at \(T\), and one where the customer has ‘died’ by \(T\).
In the first instance, the likelihood is the product of the observations of each transaction, multiplied by the likelihood of the customer still being alive at time \(T\).
Because we are modelling the transaction counts as a Poisson process, this corresponds to the times between events following an exponential distribution, and so both the transaction times and the lifetime likelihoods use the exponential.
This gives us:
\[ \begin{eqnarray*} L(\lambda \, | \, t_1, t_2, ..., t_x, T, \, \tau > T) &=& \lambda e^{-\lambda t_1} \lambda e^{-\lambda(t_2 - t_1)} ... \lambda e^{-\lambda (t_x - t_{x-1})} e^{-\lambda(T - t_x)} \\ &=& \lambda^x e^{-\lambda T} \end{eqnarray*} \]
and we can combine this with the likelihood of the lifetime of the customer \(\tau\) being greater than the observation window \(T\),
\[ P(\tau > T \, | \, \mu) = e^{-\mu T} \]
For the second case, the customer becomes inactive at some time \(\tau\) in the interval \((t_x, T]\), and so the likelihood is
\[ \begin{eqnarray*} L(\lambda \, | \, t_1, t_2, ..., t_x, \, \text{inactive at } \tau \in (t_x, T]) &=& \lambda e^{-\lambda t_1} \lambda e^{-\lambda(t_2 - t_1)} ... \lambda e^{-\lambda (t_x - t_{x-1})} e^{-\lambda(\tau - t_x)} \\ &=& \lambda^x e^{-\lambda \tau} \end{eqnarray*} \]
In both cases we do not need the times of the individual transactions, and all we need are the values \((x, t_x, T)\).
As we cannot observe \(\tau\), we want to remove the conditioning on \(\tau\) by integrating it out.
\[ \begin{eqnarray*} L(\lambda, \mu \, | \, x, t_x, T) &=& L(\lambda \, | \, t_1, t_2, ..., t_x, T, \, \tau > T) \, P(\tau > T \, | \, \mu) + \int^T_{t_x} L(\lambda \, | \, t_1, t_2, ..., t_x, \, \tau) \, f(\tau \, | \, \mu) \, d\tau \\ &=& \lambda^x e^{-\lambda T} e^{-\mu T} + \lambda^x \mu \int^T_{t_x} e^{-(\lambda + \mu) \tau} \, d\tau \\ &=& \lambda^x e^{-(\lambda + \mu)T} + \frac{\lambda^x \mu}{\lambda + \mu} e^{-(\lambda + \mu) t_x} - \frac{\lambda^x \mu}{\lambda + \mu} e^{-(\lambda + \mu) T} \\ &=& \frac{\lambda^x \mu}{\lambda + \mu} e^{-(\lambda + \mu) t_x} + \frac{\lambda^{x+1}}{\lambda + \mu} e^{-(\lambda + \mu) T} \end{eqnarray*} \]
In Stan we work with the log-likelihood rather than the likelihood, so we need to take the log of this expression. This creates a problem, as we have no easy way to calculate \(\log(a + b)\). As this expression occurs a lot, Stan provides a function log_sum_exp(), which is defined by
\[ \log \, (a + b) = \text{log_sum_exp}(\log a, \log b) \]
\[ \begin{eqnarray*} LL(\lambda, \mu \, | \, x, t_x, T) &=& \log \left( \frac{\lambda^x}{\lambda + \mu} \left( \mu \, e^{-(\lambda + \mu) t_x} + \lambda \, e^{-(\lambda + \mu) T} \right) \right) \\ &=& x \log \lambda - \log(\lambda + \mu) + \text{log_sum_exp}(\log \mu - (\lambda + \mu) \, t_x, \; \log \lambda - (\lambda + \mu) \, T) \end{eqnarray*} \]
This is the log-likelihood model we want to fit in Stan.
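Before writing the Stan version, we can prototype this calculation in R. The sketch below defines a two-argument log_sum_exp() mirroring Stan's built-in, and uses it in a single-customer log-likelihood; the function names are our own, and the cross-check uses arbitrary illustrative parameter values:

```r
# Stable evaluation of log(exp(u) + exp(v)): subtract the max before
# exponentiating so that very negative inputs do not underflow to -Inf
log_sum_exp <- function(u, v) {
  m <- max(u, v)
  m + log(exp(u - m) + exp(v - m))
}

# P/NBD log-likelihood for a single customer's (x, t_x, T) summary
pnbd_loglik <- function(lambda, mu, x, t_x, T_cal) {
  lpm <- lambda + mu
  x * log(lambda) - log(lpm) +
    log_sum_exp(log(mu) - lpm * t_x, log(lambda) - lpm * T_cal)
}

# Cross-check against the likelihood written out directly
pnbd_lik_direct <- function(lambda, mu, x, t_x, T_cal) {
  lpm <- lambda + mu
  (lambda^x / lpm) * (mu * exp(-lpm * t_x) + lambda * exp(-lpm * T_cal))
}
```

For moderate values the two routes agree; the advantage of the log-scale version is that it stays finite for long observation windows, where the direct likelihood underflows.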
3.3 Construct Datasets
Having loaded the synthetic data we need to construct a number of datasets of derived values.
customer_summarystats_tbl <- customer_transactions_tbl %>%
calculate_transaction_cbs_data(last_date = as.Date("2018-12-31"))
customer_summarystats_tbl %>% glimpse()
## Rows: 44,306
## Columns: 6
## $ customer_id    <chr> "C201101_0001", "C201101_0002", "C201101_0003", "C20110…
## $ first_tnx_date <dttm> 2011-01-01 02:19:57, 2011-01-01 00:20:28, 2011-01-01 1…
## $ last_tnx_date  <dttm> 2011-03-15 15:59:06, 2011-01-20 08:47:25, 2011-01-01 1…
## $ x              <dbl> 6, 1, 0, 4, 0, 3, 0, 2, 5, 3, 1, 1, 23, 0, 1, 2, 1, 22,…
## $ t_x            <dbl> 10.5098368, 2.7645783, 0.0000000, 11.2458596, 0.0000000…
## $ T_cal          <dbl> 417.2718, 417.2837, 417.2001, 417.2551, 417.1894, 417.1…
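calculate_transaction_cbs_data() is a workshop helper whose internals are not shown here. As a rough illustration of the idea, a minimal base-R sketch of computing the \((x, t_x, T)\) summary per customer might look like this (the name make_cbs and its interface are our own invention, not the helper's actual signature):

```r
# Compute (x, t_x, T_cal) per customer from raw transaction times.
# tnx_time is numeric (e.g. days since some origin); last_time is the
# end of the observation window.
make_cbs <- function(customer_id, tnx_time, last_time) {
  first <- tapply(tnx_time, customer_id, min)
  last  <- tapply(tnx_time, customer_id, max)
  n     <- tapply(tnx_time, customer_id, length)

  data.frame(
    customer_id = names(n),
    x     = as.numeric(n) - 1,         # repeat transactions only
    t_x   = as.numeric(last - first),  # recency, relative to first purchase
    T_cal = as.numeric(last_time - first)
  )
}

# Toy example: customer "a" buys at t = 0 and t = 5, "b" only at t = 2,
# and observation ends at t = 10
cbs <- make_cbs(c("a", "a", "b"), c(0, 5, 2), last_time = 10)
```

Note that both \(t_x\) and \(T\) are measured from each customer's first purchase, which is why a one-time buyer has \(t_x = 0\), matching the zeros visible in the glimpse output above.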
As before, we construct a number of subsets of the data for use later on in the modelling.
shuffle_tbl <- customer_summarystats_tbl %>%
slice_sample(n = nrow(.), replace = FALSE)
id_50 <- shuffle_tbl %>% head(50) %>% pull(customer_id) %>% sort()
id_1000 <- shuffle_tbl %>% head(1000) %>% pull(customer_id) %>% sort()
id_5000 <- shuffle_tbl %>% head(5000) %>% pull(customer_id) %>% sort()
id_10000 <- shuffle_tbl %>% head(10000) %>% pull(customer_id) %>% sort()
We then construct some fit data based on these values.
fit_1000_data_tbl <- customer_summarystats_tbl %>% filter(customer_id %in% id_1000)
fit_1000_data_tbl %>% glimpse()
## Rows: 1,000
## Columns: 6
## $ customer_id    <chr> "C201101_0025", "C201101_0084", "C201101_0091", "C20110…
## $ first_tnx_date <dttm> 2011-01-02 18:08:43, 2011-01-07 08:04:07, 2011-01-07 2…
## $ last_tnx_date  <dttm> 2011-02-28 14:14:05, 2012-07-22 20:11:23, 2011-01-07 2…
## $ x              <dbl> 2, 29, 0, 9, 0, 0, 16, 2, 3, 0, 6, 0, 0, 2, 7, 100, 9, …
## $ t_x            <dbl> 8.119581, 80.357865, 0.000000, 21.078214, 0.000000, 0.0…
## $ T_cal          <dbl> 417.0348, 416.3805, 416.2896, 416.2141, 416.2384, 415.0…
fit_10000_data_tbl <- customer_summarystats_tbl %>% filter(customer_id %in% id_10000)
fit_10000_data_tbl %>% glimpse()
## Rows: 10,000
## Columns: 6
## $ customer_id    <chr> "C201101_0002", "C201101_0019", "C201101_0020", "C20110…
## $ first_tnx_date <dttm> 2011-01-01 00:20:28, 2011-01-02 16:32:41, 2011-01-02 2…
## $ last_tnx_date  <dttm> 2011-01-20 08:47:25, 2011-01-02 16:32:41, 2011-01-02 2…
## $ x              <dbl> 1, 0, 0, 2, 2, 4, 2, 4, 13, 7, 1, 0, 0, 8, 5, 1, 0, 29,…
## $ t_x            <dbl> 2.764578, 0.000000, 0.000000, 8.119581, 5.265854, 21.74…
## $ T_cal          <dbl> 417.2837, 417.0444, 417.0236, 417.0348, 417.0938, 417.0…
Finally, we also want to recreate our transaction visualisation for a random sample of 50 customers.
plot_tbl <- customer_transactions_tbl %>%
filter(customer_id %in% id_50)
ggplot(plot_tbl, aes(x = tnx_timestamp, y = customer_id)) +
geom_line() +
geom_point() +
labs(
x = "Date",
y = "Customer ID",
title = "Visualisation of Customer Transaction Times"
) +
theme(axis.text.y = element_text(size = 10))
3.4 Write Data
id_1000 %>% write_rds("data/id_1000.rds")
id_5000 %>% write_rds("data/id_5000.rds")
id_10000 %>% write_rds("data/id_10000.rds")
fit_1000_data_tbl %>% write_rds("data/fit_1000_longframe_data_tbl.rds")
fit_10000_data_tbl %>% write_rds("data/fit_10000_longframe_data_tbl.rds")
customer_summarystats_tbl %>% write_rds("data/customer_summarystats_longframe_tbl.rds")
4 Fit Initial P/NBD Model
We now construct our Stan model and prepare to fit it with our synthetic dataset.
Before we start on that, we set a few parameters for the workbook to organise our Stan code.
stan_modeldir <- "stan_models"
stan_codedir <- "stan_code"
We start with the Stan model.
## functions {
## #include util_functions.stan
## }
##
## data {
## int<lower=1> n; // number of customers
##
## vector<lower=0>[n] t_x; // time to most recent purchase
## vector<lower=0>[n] T_cal; // total observation time
## vector<lower=0>[n] x; // number of purchases observed
##
## real<lower=0> lambda_mn; // prior mean for lambda
## real<lower=0> lambda_cv; // prior cv for lambda
##
## real<lower=0> mu_mn; // prior mean for mu
## real<lower=0> mu_cv; // prior cv for mu
## }
##
## transformed data {
## real<lower=0> r = 1 / (lambda_cv * lambda_cv);
## real<lower=0> alpha = 1 / (lambda_cv * lambda_cv * lambda_mn);
##
## real<lower=0> s = 1 / (mu_cv * mu_cv);
## real<lower=0> beta = 1 / (mu_cv * mu_cv * mu_mn);
## }
##
##
## parameters {
## vector<lower=0>[n] lambda; // purchase rate
## vector<lower=0>[n] mu; // lifetime dropout rate
## }
##
##
## model {
## // setting priors
## lambda ~ gamma(r, alpha);
## mu ~ gamma(s, beta);
##
## target += calculate_pnbd_loglik(n, lambda, mu, x, t_x, T_cal);
## }
##
## generated quantities {
## vector[n] p_alive; // Probability that they are still "alive"
##
## p_alive = 1 ./ (1 + mu ./ (mu + lambda) .* (exp((lambda + mu) .* (T_cal - t_x)) - 1));
## }
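The transformed data block above converts a prior mean and coefficient of variation into Gamma shape and rate parameters. We can check that algebra with a few lines of R (gamma_params is our own helper name): for a Gamma(shape, rate), the mean is shape/rate and the CV is 1/sqrt(shape).

```r
# Convert a (mean, cv) prior specification to Gamma (shape, rate),
# mirroring the transformed data block of the Stan model
gamma_params <- function(mn, cv) {
  shape <- 1 / cv^2
  rate  <- 1 / (cv^2 * mn)
  c(shape = shape, rate = rate)
}

prm <- gamma_params(mn = 0.25, cv = 1.0)

# Recover the implied mean and cv to confirm the round trip
mean_implied <- unname(prm["shape"] / prm["rate"])
cv_implied   <- unname(1 / sqrt(prm["shape"]))
```

Specifying priors through a mean and CV like this is often easier to reason about than choosing shape and rate directly.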
This file uses a few new features of Stan - named file includes and user-defined functions such as calculate_pnbd_loglik(). We look at the included file here:
## real calculate_pnbd_loglik(int n, vector lambda, vector mu,
## data vector x, data vector t_x, data vector T_cal) {
## // likelihood
## vector[n] t1;
## vector[n] t2;
##
## vector[n] lpm;
## vector[n] lht;
## vector[n] rht;
##
## lpm = lambda + mu;
##
## lht = log(lambda) - lpm .* T_cal;
## rht = log(mu) - lpm .* t_x;
##
## t1 = x .* log(lambda) - log(lpm);
##
## for (i in 1:n) {
## t2[i] = log_sum_exp(lht[i], rht[i]);
## }
##
## return(sum(t1) + sum(t2));
## }
##
##
## real calculate_bgnbd_loglik(int n, vector lambda, vector p,
## data vector x, data vector t_x, data vector T_cal) {
## // likelihood
## vector[n] t1;
## vector[n] t2;
##
## vector[n] lht;
## vector[n] rht;
##
## lht = log(p) + (x-1) .* log(1-p) + x .* log(lambda) - lambda .* t_x;
## rht = x .* log(1-p) + x .* log(lambda) - lambda .* T_cal;
##
## for(i in 1:n) {
## t2[i] = log_sum_exp(lht[i], rht[i]);
## }
##
## return(sum(t2));
## }
We now compile this model using CmdStanR.
pnbd_fixed_stanmodel <- cmdstan_model(
"stan_code/pnbd_fixed.stan",
include_paths = stan_codedir,
pedantic = TRUE,
dir = stan_modeldir
)
We then use this compiled model with our data to produce a fit of the data.
stan_modelname <- "pnbd_fixed"
stanfit_prefix <- str_c("fit_", stan_modelname)
stan_data_lst <- fit_1000_data_tbl %>%
select(customer_id, x, t_x, T_cal) %>%
compose_data(
lambda_mn = 0.25,
lambda_cv = 1.00,
mu_mn = 0.10,
mu_cv = 1.00,
)
pnbd_fixed_stanfit <- pnbd_fixed_stanmodel$sample(
data = stan_data_lst,
chains = 4,
iter_warmup = 500,
iter_sampling = 500,
seed = 4201,
save_warmup = TRUE,
output_dir = stan_modeldir,
output_basename = stanfit_prefix,
)
## Running MCMC with 4 chains, at most 8 in parallel...
##
## Chain 1 Iteration: 1 / 1000 [ 0%] (Warmup)
## Chain 2 Iteration: 1 / 1000 [ 0%] (Warmup)
## Chain 3 Iteration: 1 / 1000 [ 0%] (Warmup)
## Chain 4 Iteration: 1 / 1000 [ 0%] (Warmup)
## Chain 1 Iteration: 100 / 1000 [ 10%] (Warmup)
## Chain 4 Iteration: 100 / 1000 [ 10%] (Warmup)
## Chain 3 Iteration: 100 / 1000 [ 10%] (Warmup)
## Chain 2 Iteration: 100 / 1000 [ 10%] (Warmup)
## Chain 1 Iteration: 200 / 1000 [ 20%] (Warmup)
## Chain 4 Iteration: 200 / 1000 [ 20%] (Warmup)
## Chain 3 Iteration: 200 / 1000 [ 20%] (Warmup)
## Chain 2 Iteration: 200 / 1000 [ 20%] (Warmup)
## Chain 1 Iteration: 300 / 1000 [ 30%] (Warmup)
## Chain 4 Iteration: 300 / 1000 [ 30%] (Warmup)
## Chain 3 Iteration: 300 / 1000 [ 30%] (Warmup)
## Chain 2 Iteration: 300 / 1000 [ 30%] (Warmup)
## Chain 1 Iteration: 400 / 1000 [ 40%] (Warmup)
## Chain 4 Iteration: 400 / 1000 [ 40%] (Warmup)
## Chain 3 Iteration: 400 / 1000 [ 40%] (Warmup)
## Chain 2 Iteration: 400 / 1000 [ 40%] (Warmup)
## Chain 1 Iteration: 500 / 1000 [ 50%] (Warmup)
## Chain 1 Iteration: 501 / 1000 [ 50%] (Sampling)
## Chain 4 Iteration: 500 / 1000 [ 50%] (Warmup)
## Chain 4 Iteration: 501 / 1000 [ 50%] (Sampling)
## Chain 3 Iteration: 500 / 1000 [ 50%] (Warmup)
## Chain 3 Iteration: 501 / 1000 [ 50%] (Sampling)
## Chain 2 Iteration: 500 / 1000 [ 50%] (Warmup)
## Chain 2 Iteration: 501 / 1000 [ 50%] (Sampling)
## Chain 1 Iteration: 600 / 1000 [ 60%] (Sampling)
## Chain 4 Iteration: 600 / 1000 [ 60%] (Sampling)
## Chain 3 Iteration: 600 / 1000 [ 60%] (Sampling)
## Chain 2 Iteration: 600 / 1000 [ 60%] (Sampling)
## Chain 1 Iteration: 700 / 1000 [ 70%] (Sampling)
## Chain 4 Iteration: 700 / 1000 [ 70%] (Sampling)
## Chain 3 Iteration: 700 / 1000 [ 70%] (Sampling)
## Chain 2 Iteration: 700 / 1000 [ 70%] (Sampling)
## Chain 1 Iteration: 800 / 1000 [ 80%] (Sampling)
## Chain 4 Iteration: 800 / 1000 [ 80%] (Sampling)
## Chain 3 Iteration: 800 / 1000 [ 80%] (Sampling)
## Chain 2 Iteration: 800 / 1000 [ 80%] (Sampling)
## Chain 1 Iteration: 900 / 1000 [ 90%] (Sampling)
## Chain 4 Iteration: 900 / 1000 [ 90%] (Sampling)
## Chain 3 Iteration: 900 / 1000 [ 90%] (Sampling)
## Chain 2 Iteration: 900 / 1000 [ 90%] (Sampling)
## Chain 1 Iteration: 1000 / 1000 [100%] (Sampling)
## Chain 1 finished in 15.8 seconds.
## Chain 4 Iteration: 1000 / 1000 [100%] (Sampling)
## Chain 4 finished in 16.2 seconds.
## Chain 3 Iteration: 1000 / 1000 [100%] (Sampling)
## Chain 3 finished in 16.6 seconds.
## Chain 2 Iteration: 1000 / 1000 [100%] (Sampling)
## Chain 2 finished in 16.8 seconds.
##
## All 4 chains finished successfully.
## Mean chain execution time: 16.4 seconds.
## Total execution time: 17.1 seconds.
pnbd_fixed_stanfit$summary()
## # A tibble: 3,001 × 10
## variable mean median sd mad q5 q95 rhat ess_bulk
## <chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
## 1 lp__ -1.65e+4 -1.65e+4 34.1 34.6 -1.66e+4 -1.65e+4 1.00 635.
## 2 lambda[1] 1.96e-1 1.66e-1 0.127 0.109 4.55e-2 4.39e-1 1.00 3982.
## 3 lambda[2] 3.45e-1 3.41e-1 0.0618 0.0592 2.52e-1 4.52e-1 1.01 4310.
## 4 lambda[3] 1.40e-1 8.31e-2 0.164 0.0925 5.61e-3 4.95e-1 1.00 2831.
## 5 lambda[4] 3.66e-1 3.51e-1 0.117 0.118 1.97e-1 5.83e-1 0.999 4289.
## 6 lambda[5] 1.49e-1 8.34e-2 0.181 0.0983 3.82e-3 5.28e-1 1.00 2645.
## 7 lambda[6] 1.35e-1 8.15e-2 0.167 0.0928 5.16e-3 4.55e-1 0.999 2872.
## 8 lambda[7] 1.40e-1 1.37e-1 0.0352 0.0334 8.71e-2 2.00e-1 1.01 5067.
## 9 lambda[8] 1.80e-1 1.51e-1 0.121 0.103 3.75e-2 4.11e-1 1.00 3065.
## 10 lambda[9] 2.75e-1 2.45e-1 0.149 0.136 8.42e-2 5.53e-1 1.00 4007.
## # … with 2,991 more rows, and 1 more variable: ess_tail <dbl>
We have some basic HMC-based validity statistics we can check.
pnbd_fixed_stanfit$cmdstan_diagnose()
## Processing csv files: /home/rstudio/workshop/stan_models/fit_pnbd_fixed-1.csvWarning: non-fatal error reading adaptation data
## , /home/rstudio/workshop/stan_models/fit_pnbd_fixed-2.csvWarning: non-fatal error reading adaptation data
## , /home/rstudio/workshop/stan_models/fit_pnbd_fixed-3.csvWarning: non-fatal error reading adaptation data
## , /home/rstudio/workshop/stan_models/fit_pnbd_fixed-4.csvWarning: non-fatal error reading adaptation data
##
##
## Checking sampler transitions treedepth.
## Treedepth satisfactory for all transitions.
##
## Checking sampler transitions for divergences.
## No divergent transitions found.
##
## Checking E-BFMI - sampler transitions HMC potential energy.
## E-BFMI satisfactory.
##
## Effective sample size satisfactory.
##
## Split R-hat values satisfactory all parameters.
##
## Processing complete, no problems detected.
4.1 Visual Diagnostics of the Sample Validity
Now that we have a sample from the posterior distribution we need to create a few different visualisations of the diagnostics.
parameter_subset <- c(
"lambda[1]", "lambda[2]", "lambda[3]", "lambda[4]",
"mu[1]", "mu[2]", "mu[3]", "mu[4]"
)
pnbd_fixed_stanfit$draws(inc_warmup = TRUE) %>%
mcmc_trace(
pars = parameter_subset,
n_warmup = 500
) +
ggtitle("Full Traceplots of Some Lambda and Mu Values")
As the warmup is skewing the y-axis somewhat, we repeat this process without the warmup.
pnbd_fixed_stanfit$draws(inc_warmup = FALSE) %>%
mcmc_trace(pars = parameter_subset) +
expand_limits(y = 0) +
labs(
x = "Iteration",
y = "Value",
title = "Traceplot of Sample of Lambda and Mu Values"
) +
theme(axis.text.x = element_text(size = 10))
A common MCMC diagnostic is \(\hat{R}\) - a measure of the ‘similarity’ of the chains.
pnbd_fixed_stanfit %>%
rhat(pars = c("lambda", "mu")) %>%
mcmc_rhat() +
ggtitle("Plot of Parameter R-hat Values")
Related to this quantity is the concept of effective sample size, \(N_{eff}\), an estimate of the size of the sample from a statistical information point of view.
pnbd_fixed_stanfit %>%
neff_ratio(pars = c("lambda", "mu")) %>%
mcmc_neff() +
ggtitle("Plot of Parameter Effective Sample Sizes")
Finally, we also want to look at autocorrelation in the chains for each parameter.
pnbd_fixed_stanfit$draws() %>%
mcmc_acf(pars = parameter_subset) +
ggtitle("Autocorrelation Plot of Sample Values")
As before, this first fit has a comprehensive run of fit diagnostics, but for the sake of brevity in later models we will show only the traceplots once we are satisfied with the validity of the sample.
4.2 Check Model Fit
As we are still working with synthetic data, we know the true values for each customer and so we can check how good our model is at recovering the true values on a customer-by-customer basis.
As in previous workbooks, we build our validation datasets and then check the distribution of \(q\)-values for both \(\lambda\) and \(\mu\) across the customer base.
pnbd_fixed_valid_lst <- create_pnbd_posterior_validation_data(
stanfit = pnbd_fixed_stanfit,
data_tbl = fit_1000_data_tbl,
simparams_tbl = customer_simparams_tbl
)
pnbd_fixed_valid_lst$lambda_qval_plot %>% plot()
pnbd_fixed_valid_lst$mu_qval_plot %>% plot()
5 Fit Alternate Prior P/NBD Model
We now repeat all of the above but with an alternative set of priors and compare the outputs to give us an idea of the sensitivity of the inference to the input priors.
We will repeat this a few times, so we start by increasing the coefficient of variation in the priors. We keep everything else constant, including the seed used by Stan.
stan_modelname <- "pnbd_fixed2"
stanfit_prefix <- str_c("fit_", stan_modelname)
stan_data_lst <- fit_1000_data_tbl %>%
select(customer_id, x, t_x, T_cal) %>%
compose_data(
lambda_mn = 0.25,
lambda_cv = 2.00,
mu_mn = 0.10,
mu_cv = 2.00,
)
pnbd_fixed2_stanfit <- pnbd_fixed_stanmodel$sample(
data = stan_data_lst,
chains = 4,
iter_warmup = 500,
iter_sampling = 500,
seed = 4201,
save_warmup = TRUE,
output_dir = stan_modeldir,
output_basename = stanfit_prefix,
)
## Running MCMC with 4 chains, at most 8 in parallel...
##
## Chain 1 Iteration: 1 / 1000 [ 0%] (Warmup)
## Chain 2 Iteration: 1 / 1000 [ 0%] (Warmup)
## Chain 3 Iteration: 1 / 1000 [ 0%] (Warmup)
## Chain 4 Iteration: 1 / 1000 [ 0%] (Warmup)
## Chain 3 Iteration: 100 / 1000 [ 10%] (Warmup)
## Chain 2 Iteration: 100 / 1000 [ 10%] (Warmup)
## Chain 1 Iteration: 100 / 1000 [ 10%] (Warmup)
## Chain 4 Iteration: 100 / 1000 [ 10%] (Warmup)
## Chain 3 Iteration: 200 / 1000 [ 20%] (Warmup)
## Chain 2 Iteration: 200 / 1000 [ 20%] (Warmup)
## Chain 4 Iteration: 200 / 1000 [ 20%] (Warmup)
## Chain 1 Iteration: 200 / 1000 [ 20%] (Warmup)
## Chain 3 Iteration: 300 / 1000 [ 30%] (Warmup)
## Chain 2 Iteration: 300 / 1000 [ 30%] (Warmup)
## Chain 4 Iteration: 300 / 1000 [ 30%] (Warmup)
## Chain 1 Iteration: 300 / 1000 [ 30%] (Warmup)
## Chain 3 Iteration: 400 / 1000 [ 40%] (Warmup)
## Chain 2 Iteration: 400 / 1000 [ 40%] (Warmup)
## Chain 4 Iteration: 400 / 1000 [ 40%] (Warmup)
## Chain 3 Iteration: 500 / 1000 [ 50%] (Warmup)
## Chain 3 Iteration: 501 / 1000 [ 50%] (Sampling)
## Chain 1 Iteration: 400 / 1000 [ 40%] (Warmup)
## Chain 2 Iteration: 500 / 1000 [ 50%] (Warmup)
## Chain 2 Iteration: 501 / 1000 [ 50%] (Sampling)
## Chain 4 Iteration: 500 / 1000 [ 50%] (Warmup)
## Chain 4 Iteration: 501 / 1000 [ 50%] (Sampling)
## Chain 3 Iteration: 600 / 1000 [ 60%] (Sampling)
## Chain 1 Iteration: 500 / 1000 [ 50%] (Warmup)
## Chain 1 Iteration: 501 / 1000 [ 50%] (Sampling)
## Chain 2 Iteration: 600 / 1000 [ 60%] (Sampling)
## Chain 4 Iteration: 600 / 1000 [ 60%] (Sampling)
## Chain 3 Iteration: 700 / 1000 [ 70%] (Sampling)
## Chain 1 Iteration: 600 / 1000 [ 60%] (Sampling)
## Chain 2 Iteration: 700 / 1000 [ 70%] (Sampling)
## Chain 4 Iteration: 700 / 1000 [ 70%] (Sampling)
## Chain 3 Iteration: 800 / 1000 [ 80%] (Sampling)
## Chain 1 Iteration: 700 / 1000 [ 70%] (Sampling)
## Chain 2 Iteration: 800 / 1000 [ 80%] (Sampling)
## Chain 4 Iteration: 800 / 1000 [ 80%] (Sampling)
## Chain 3 Iteration: 900 / 1000 [ 90%] (Sampling)
## Chain 1 Iteration: 800 / 1000 [ 80%] (Sampling)
## Chain 2 Iteration: 900 / 1000 [ 90%] (Sampling)
## Chain 4 Iteration: 900 / 1000 [ 90%] (Sampling)
## Chain 3 Iteration: 1000 / 1000 [100%] (Sampling)
## Chain 3 finished in 37.3 seconds.
## Chain 1 Iteration: 900 / 1000 [ 90%] (Sampling)
## Chain 2 Iteration: 1000 / 1000 [100%] (Sampling)
## Chain 2 finished in 38.9 seconds.
## Chain 4 Iteration: 1000 / 1000 [100%] (Sampling)
## Chain 4 finished in 39.2 seconds.
## Chain 1 Iteration: 1000 / 1000 [100%] (Sampling)
## Chain 1 finished in 41.5 seconds.
##
## All 4 chains finished successfully.
## Mean chain execution time: 39.2 seconds.
## Total execution time: 41.7 seconds.
pnbd_fixed2_stanfit$summary()
## # A tibble: 3,001 × 10
## variable mean median sd mad q5 q95 rhat ess_bulk
## <chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
## 1 lp__ -1.21e+4 -1.21e+4 39.9 4.17e+1 -1.21e+4 -1.20e+4 1.01 544.
## 2 lambda[1] 1.71e-1 1.41e-1 0.128 1.07e-1 2.62e-2 4.25e-1 1.00 2504.
## 3 lambda[2] 3.48e-1 3.44e-1 0.0649 6.46e-2 2.50e-1 4.63e-1 1.00 4383.
## 4 lambda[3] 5.28e-2 2.28e-3 0.165 3.37e-3 6.91e-8 2.64e-1 0.999 1276.
## 5 lambda[4] 3.80e-1 3.69e-1 0.128 1.24e-1 1.96e-1 6.11e-1 0.999 3694.
## 6 lambda[5] 5.17e-2 2.31e-3 0.173 3.42e-3 1.00e-7 2.76e-1 1.00 1080.
## 7 lambda[6] 5.01e-2 2.34e-3 0.150 3.46e-3 2.85e-7 2.68e-1 1.00 1229.
## 8 lambda[7] 1.36e-1 1.34e-1 0.0340 3.21e-2 8.35e-2 1.95e-1 1.00 4566.
## 9 lambda[8] 1.59e-1 1.27e-1 0.125 1.01e-1 2.28e-2 4.16e-1 1.00 2887.
## 10 lambda[9] 2.80e-1 2.44e-1 0.174 1.56e-1 6.81e-2 6.08e-1 1.00 2492.
## # … with 2,991 more rows, and 1 more variable: ess_tail <dbl>
We have some basic HMC-based validity statistics we can check.
pnbd_fixed2_stanfit$cmdstan_diagnose()
## Processing csv files: /home/rstudio/workshop/stan_models/fit_pnbd_fixed2-1.csvWarning: non-fatal error reading adaptation data
## , /home/rstudio/workshop/stan_models/fit_pnbd_fixed2-2.csvWarning: non-fatal error reading adaptation data
## , /home/rstudio/workshop/stan_models/fit_pnbd_fixed2-3.csvWarning: non-fatal error reading adaptation data
## , /home/rstudio/workshop/stan_models/fit_pnbd_fixed2-4.csvWarning: non-fatal error reading adaptation data
##
##
## Checking sampler transitions treedepth.
## Treedepth satisfactory for all transitions.
##
## Checking sampler transitions for divergences.
## No divergent transitions found.
##
## Checking E-BFMI - sampler transitions HMC potential energy.
## E-BFMI satisfactory.
##
## Effective sample size satisfactory.
##
## The following parameters had split R-hat greater than 1.05:
## p_alive[433], p_alive[474], p_alive[476], p_alive[920]
## Such high values indicate incomplete mixing and biased estimation.
## You should consider regularizating your model with additional prior information or a more effective parameterization.
##
## Processing complete.
5.1 Visual Diagnostics of the Sample Validity
We do not repeat the full set of validation checks here, but look at the plot of effective sample size ratios.
pnbd_fixed2_stanfit %>%
neff_ratio(pars = c("lambda", "mu")) %>%
mcmc_neff() +
ggtitle("Plot of Parameter Effective Sample Sizes for Alternate Priors")
5.2 Check Model Fit
As we are still working with synthetic data, we know the true values for each customer and so we can check how good our model is at recovering the true values on a customer-by-customer basis.
As in previous workbooks, we build our validation datasets and then check the distribution of \(q\)-values for both \(\lambda\) and \(\mu\) across the customer base.
pnbd_fixed2_valid_lst <- create_pnbd_posterior_validation_data(
stanfit = pnbd_fixed2_stanfit,
data_tbl = fit_1000_data_tbl,
simparams_tbl = customer_simparams_tbl
)
pnbd_fixed2_valid_lst$lambda_qval_plot %>% plot()
pnbd_fixed2_valid_lst$mu_qval_plot %>% plot()
We now construct some error-bar plots for a subset of the customers to get an idea of the differences across the two posteriors.
comparison_plot_tbl <- list(
fixed = pnbd_fixed_valid_lst$validation_tbl,
fixed2 = pnbd_fixed2_valid_lst$validation_tbl
) %>%
bind_rows(.id = "post_label") %>%
group_by(post_label, customer_id) %>%
summarise(
.groups = "drop",
p10 = quantile(post_lambda, 0.10),
p25 = quantile(post_lambda, 0.25),
p50 = quantile(post_lambda, 0.50),
p75 = quantile(post_lambda, 0.75),
p90 = quantile(post_lambda, 0.90),
customer_lambda = customer_lambda %>% unique()
) %>%
group_nest(customer_id) %>%
slice_sample(n = 50) %>%
unnest(data)
ggplot(comparison_plot_tbl) +
geom_point(aes(x = customer_id, y = customer_lambda)) +
geom_errorbar(
aes(x = customer_id, ymin = p25, ymax = p75, colour = post_label),
position = position_dodge(width = 0.75), width = 0, size = 3,
) +
geom_errorbar(
aes(x = customer_id, ymin = p10, ymax = p90, colour = post_label),
position = position_dodge(width = 0.75), width = 0, size = 1,
) +
theme(
axis.text.x = element_text(angle = 90, vjust = 0.5)
)
6 Assessing the P/NBD Models
We now focus on assessing the various versions of our models without the benefit of knowing the ‘true’ answer, as we will not have this information in our real-world applications.
A key derived quantity to employ is \(p_{alive}(T)\), the probability that the customer is still ‘alive’ at the point of observation of the data.
6.1 Calculating p_alive
The probability that a customer with purchase history \((x, t_x, T)\) is ‘alive’ at time \(T\) is the probability that the (unobserved) time at which the customer becomes inactive, \(\tau\), occurs after \(T\); that is, \(P(\tau > T)\).
We apply Bayes’ Theorem to give us:
\[ \begin{eqnarray*} P(\tau > T \, | \, \lambda, \mu, x, t_x, T) &=& \frac{L(\lambda \, | \, x, T, \tau > T) \, P(\tau > T \, | \, \mu)} {L(\lambda, \mu \, | \, x, t_x, T)} \\ &=& \frac{\lambda^x e^{-(\lambda+\mu)T}} {L(\lambda, \mu \, | \, x, t_x, T)} \end{eqnarray*} \]
We know the likelihood for this model, so can now substitute this in:
\[ \begin{eqnarray*} P(\tau > T \, | \, \lambda, \mu, x, t_x, T) &=& \frac{\lambda^x e^{-(\lambda+\mu)T}} {\frac{\lambda^x \mu}{\lambda + \mu} e^{-(\lambda + \mu) t_x} + \frac{\lambda^{x+1}}{\lambda + \mu} e^{-(\lambda + \mu) T}} \\ &=& \frac{\lambda^x \, e^{-(\lambda + \mu)T}} {\lambda^x e^{-(\lambda + \mu)T} \left\{1 + \left(\frac{\mu}{\lambda + \mu}\right) \left[e^{(\lambda + \mu)(T - t_x)} - 1 \right]\right\}} \\ &=& \frac{1}{1 + \left(\frac{\mu}{\lambda + \mu}\right) \left[e^{(\lambda + \mu)(T - t_x)} - 1 \right]} \end{eqnarray*} \]
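The final expression translates directly into a small helper function. The sketch below is illustrative (the function name is hypothetical; the workbook's own pipeline computes p_alive internally). Note that \(x\) cancels out of the final expression, so only \(t_x\) and \(T\) are needed alongside the rates:

```r
# Probability a customer is still 'alive' at the end of the observation
# window T_cal, given transaction rate lambda, dropout rate mu, and the
# time of the last observed transaction t_x. Uses T_cal rather than T to
# avoid masking R's built-in TRUE shorthand.
calculate_p_alive <- function(lambda, mu, t_x, T_cal) {
  1 / (1 + (mu / (lambda + mu)) * (exp((lambda + mu) * (T_cal - t_x)) - 1))
}

# A customer whose last transaction falls at the end of the window is
# alive with probability 1; the probability decays as T_cal - t_x grows.
calculate_p_alive(lambda = 0.25, mu = 0.10, t_x = 52, T_cal = 52)  # 1
calculate_p_alive(lambda = 0.25, mu = 0.10, t_x = 20, T_cal = 52)
```

Being vectorised over its arguments, this helper can be applied directly to posterior draws of \(\lambda\) and \(\mu\) to produce a posterior distribution for p_alive per customer.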
We now want to verify this calculation in our posterior estimates. To do this, we take a number of different cohorts of customers and visually inspect their transaction timelines.
pnbd_fixed_palive_summary_tbl <- pnbd_fixed_valid_lst$validation_tbl %>%
group_by(customer_id) %>%
summarise(
.groups = "drop",
p_alive_p10 = quantile(p_alive, 0.10),
p_alive_p25 = quantile(p_alive, 0.25),
p_alive_p50 = quantile(p_alive, 0.50),
p_alive_p75 = quantile(p_alive, 0.75),
p_alive_p90 = quantile(p_alive, 0.90),
p_alive_range50 = p_alive_p75 - p_alive_p25,
p_alive_range80 = p_alive_p90 - p_alive_p10,
)
pnbd_fixed_palive_summary_tbl %>% glimpse()
## Rows: 1,000
## Columns: 8
## $ customer_id     <chr> "C201101_0025", "C201101_0084", "C201101_0091", "C2011…
## $ p_alive_p10     <dbl> 1.063788e-85, 5.803798e-65, 1.310305e-102, 9.548350e-1…
## $ p_alive_p25     <dbl> 2.433787e-64, 9.276485e-59, 8.515800e-67, 1.568615e-87…
## $ p_alive_p50     <dbl> 8.500835e-47, 1.778625e-52, 7.813795e-42, 1.564965e-70…
## $ p_alive_p75     <dbl> 1.652795e-31, 9.765512e-47, 6.032905e-24, 1.333278e-56…
## $ p_alive_p90     <dbl> 8.281591e-22, 2.221119e-41, 7.266608e-14, 9.218306e-47…
## $ p_alive_range50 <dbl> 1.652795e-31, 9.765512e-47, 6.032905e-24, 1.333278e-56…
## $ p_alive_range80 <dbl> 8.281591e-22, 2.221119e-41, 7.266608e-14, 9.218306e-47…
We now take the customers that are highly likely to still be active and those highly likely to no longer be active, using the 10th and 90th percentiles of p_alive as thresholds.
likely_active_id <- pnbd_fixed_palive_summary_tbl %>%
filter(p_alive_p10 > 0.95) %>%
pull(customer_id)
plot_tbl <- customer_transactions_tbl %>%
filter(
tnx_timestamp < as.Date("2019-01-01"),
customer_id %in% likely_active_id
)
ggplot(plot_tbl, aes(x = tnx_timestamp, y = customer_id)) +
geom_line() +
geom_point() +
geom_vline(aes(xintercept = as.POSIXct("2019-01-01")), colour = "red") +
labs(
x = "Date",
y = "Customer ID",
title = "Visualisation of Transaction Times for Likely Active Customers"
) +
theme(axis.text.y = element_text(size = 10))
This is useful, but the longer period of time prevents us from seeing the more recent transactions clearly, so we focus on the final year.
likely_active_id <- pnbd_fixed_palive_summary_tbl %>%
filter(p_alive_p10 > 0.95) %>%
pull(customer_id)
plot_tbl <- customer_transactions_tbl %>%
filter(
tnx_timestamp < as.Date("2019-01-01"),
tnx_timestamp >= as.Date("2018-01-01"),
customer_id %in% likely_active_id
)
ggplot(plot_tbl, aes(x = tnx_timestamp, y = customer_id)) +
geom_line() +
geom_point() +
geom_vline(aes(xintercept = as.POSIXct("2019-01-01")), colour = "red") +
labs(
x = "Date",
y = "Customer ID",
title = "Visualisation of Transaction Times for Likely Active Customers"
) +
theme(axis.text.y = element_text(size = 10))
We see that most of the active customers have reasonably recent transactions.
likely_inactive_id <- pnbd_fixed_palive_summary_tbl %>%
filter(p_alive_p90 < 0.05) %>%
pull(customer_id)
plot_tbl <- customer_transactions_tbl %>%
filter(
tnx_timestamp < as.Date("2019-01-01"),
customer_id %in% likely_inactive_id
)
ggplot(plot_tbl, aes(x = tnx_timestamp, y = customer_id)) +
geom_line(alpha = 0.1) +
geom_point(alpha = 0.1, size = 1) +
geom_vline(aes(xintercept = as.POSIXct("2019-01-01")), colour = "red") +
labs(
x = "Date",
y = "Customer",
title = "Visualisation of Transaction Times for Likely Inactive Customers"
) +
theme(
axis.text.y = element_blank(),
axis.ticks.y = element_blank()
)
Finally, we want to look at the customers with the most uncertainty in their value for p_alive. There are a number of ways to do this, but we start by looking at the range in values.
First we look at the distribution of these ranges.
ggplot(pnbd_fixed_palive_summary_tbl) +
geom_histogram(aes(x = p_alive_range80), binwidth = 0.02) +
labs(
x = "80% Interval Range",
y = "Frequency",
title = "Distribution of Range of Values for the 80% Credibility Range for p_alive"
)
We see that most customers have their credibility intervals in a reasonably tight range, and so we can just look at customers where this range is at least 0.3.
plot_tbl <- pnbd_fixed_palive_summary_tbl %>%
filter(p_alive_range80 > 0.3) %>%
select(customer_id, p_alive_range80) %>%
inner_join(customer_transactions_tbl, by = "customer_id") %>%
filter(
tnx_timestamp < as.Date("2019-01-01"),
tnx_timestamp >= as.Date("2018-01-01")
)
ggplot(plot_tbl, aes(x = tnx_timestamp, y = customer_id, colour = p_alive_range80)) +
geom_line(alpha = 0.5) +
geom_point(alpha = 0.5, size = 1) +
geom_vline(aes(xintercept = as.POSIXct("2019-01-01")), colour = "red") +
scale_colour_gradient(low = "blue", high = "red") +
labs(
x = "Date",
y = "Customer",
colour = "Interval Range",
title = "Visualisation of Transaction Times for Non-Confident p_alive Customers"
) +
theme(
axis.text.y = element_blank(),
axis.ticks.y = element_blank()
)
6.2 Constructing Model Validation Approaches
We can now use the three inferred quantities \(\lambda\), \(\mu\), and p_alive as inputs to data-generating simulations, and compare the observed data to our simulated data.
Recall that our models were fit on data that occurred before a specific date: for this data the cutoff date is 2018-12-31.
We can approach this generative process in a number of different ways, but the most natural appears to be the following (repeated for each customer in the training dataset):
- Using p_alive, simulate whether the customer is still active.
- If active, simulate \(\tau'\), the remaining lifetime beyond the observed \(T\), using \(\mu\): \(\tau' \sim \text{Exponential}(\mu)\).
- Simulate inter-event times \(t_x'\) onwards from \(T\): \(t_x' \sim \text{Exponential}(\lambda)\).
- Keep all events where the cumulative sum of these time intervals is less than \(\tau'\).
Note that the above approach is simplified due to the ‘memoryless’ nature of the Exponential distribution. Also, we could just simulate event counts using \(x' \sim \text{Poisson}(\lambda \tau')\) but it is more useful to also simulate time intervals between the transactions.
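A minimal sketch of this per-draw procedure is shown below. The function name and arguments here are illustrative only, not the signature of the workbook's actual helper:

```r
# Simulate future transaction times for one posterior draw, following the
# steps above: (1) coin-flip on p_alive, (2) draw the remaining lifetime
# tau' ~ Exponential(mu), (3) draw inter-event gaps ~ Exponential(lambda),
# keeping events that fall within both tau' and the forecast horizon.
simulate_future_transactions <- function(lambda, mu, p_alive, horizon) {
  if (runif(1) > p_alive) return(numeric(0))  # customer already churned

  tau_rem <- rexp(1, rate = mu)               # remaining lifetime after T
  t_end   <- min(tau_rem, horizon)

  event_times <- numeric(0)
  t_cur <- rexp(1, rate = lambda)
  while (t_cur < t_end) {
    event_times <- c(event_times, t_cur)
    t_cur <- t_cur + rexp(1, rate = lambda)
  }

  event_times
}

set.seed(42)
simulate_future_transactions(lambda = 0.25, mu = 0.05, p_alive = 0.9, horizon = 52)
```

Running this once per posterior draw, per customer, yields the posterior-predictive distribution of transaction counts and times used in the validation below.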
To perform this simulation, we write a function
generate_pnbd_validation_transactions() to perform the required calculations.
Thinking ahead, this function will also include parameters for the transaction
amount, but for these models we set both those values to 1.
pnbd_fixed_validsims_tbl <- pnbd_fixed_valid_lst$validation_tbl %>%
group_nest(customer_id, .key = "cust_params") %>%
mutate(
sim_file = glue(
"precompute/pnbd_fixed/sims_pnbd_fixed_{customer_id}.rds"
),
chunk_data = future_map2(
sim_file, cust_params,
run_pnbd_simulations_chunk,
.options = furrr_options(
globals = c(
"calculate_event_times", "rgamma_mucv", "gamma_mucv2shaperate",
"generate_pnbd_validation_transactions"
),
packages = c("tidyverse", "fs"),
scheduling = FALSE,
seed = 421
),
.progress = TRUE
)
) %>%
select(-cust_params) %>%
unnest(chunk_data)
pnbd_fixed_validsims_tbl %>% glimpse()
## Rows: 1,000
## Columns: 3
## $ customer_id <chr> "C201101_0025", "C201101_0084", "C201101_0091", "C201101_0…
## $ sim_file    <glue> "precompute/pnbd_fixed/sims_pnbd_fixed_C201101_0025.rds",…
## $ chunk_data  <lgl> FALSE, FALSE, FALSE, FALSE, FALSE, FALSE, FALSE, FALSE, FA…
We also want the summary statistics for the transactions in the final year that is not included in the fitted data.
obs_2019_stats_tbl <- customer_transactions_tbl %>%
filter(
tnx_timestamp >= as.POSIXct("2019-01-01")
) %>%
group_by(customer_id) %>%
summarise(
.groups = "drop",
tnx_count = n(),
first_tnx = min(tnx_timestamp),
last_tnx = max(tnx_timestamp)
)
obs_2019_stats_tbl %>% glimpse()
## Rows: 7,145
## Columns: 4
## $ customer_id <chr> "C201104_0224", "C201104_0449", "C201105_0314", "C201108_0…
## $ tnx_count   <int> 6, 3, 6, 11, 2, 1, 11, 20, 3, 2, 9, 12, 1, 12, 1, 1, 32, 1…
## $ first_tnx   <dttm> 2019-01-01 10:41:33, 2019-05-12 20:40:15, 2019-02-09 19:1…
## $ last_tnx    <dttm> 2019-07-05 12:16:44, 2019-07-14 07:02:47, 2019-11-25 13:2…
We now want to combine this data to start our investigation of the model outputs. We then compare the distribution of simulated transaction counts for our posterior against the observed count in the data by calculating the \(q\)-value.
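The \(q\)-value here is simply where the observed count falls in the empirical CDF of the simulated counts. A toy illustration (the counts below are made up for demonstration):

```r
# q-value: the proportion of simulated counts less than or equal to the
# observed count. Values near 0 or 1 flag observations sitting in the
# tails of the posterior-predictive distribution.
sim_counts <- c(0, 0, 1, 1, 2, 3, 3, 4, 5, 8)  # toy posterior simulations
obs_count  <- 3

qval <- ecdf(sim_counts)(obs_count)
qval  # 0.7 -- 7 of the 10 simulated counts are <= 3
```

If the model is well calibrated, these \(q\)-values should be roughly uniform across the customer base.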
pnbd_fixed_1000_validation_tbl <- pnbd_fixed_validsims_tbl %>%
group_by(customer_id) %>%
summarise(
.groups = "drop",
sim_count = list(map_int(sim_file, ~ .x %>% read_rds() %>% nrow()))
) %>%
left_join(obs_2019_stats_tbl, by = "customer_id") %>%
replace_na(list(tnx_count = 0)) %>%
mutate(
tnx_count_qval = map2_dbl(sim_count, tnx_count, ~ ecdf(.x)(.y))
)
pnbd_fixed_1000_validation_tbl %>% glimpse()
## Rows: 1,000
## Columns: 6
## $ customer_id    <chr> "C201101_0025", "C201101_0084", "C201101_0091", "C20110…
## $ sim_count      <list> 2000, 2000, 2000, 2000, 2000, 2000, 2000, 2000, 2000, …
## $ tnx_count      <int> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ first_tnx      <dttm> NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA…
## $ last_tnx       <dttm> NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA…
## $ tnx_count_qval <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
We first look at the distribution of these \(q\)-values.
ggplot(pnbd_fixed_1000_validation_tbl) +
geom_histogram(aes(x = tnx_count_qval), bins = 50) +
labs(
x = "q-Value for Transaction Count",
y = "Frequency",
title = "Histogram of Transaction Count ECDF q-Values"
)
q-values of 1.0 are inflated in this data: a number of customers are marked as highly likely to be inactive, so their posterior simulations contain all-zero transaction counts, and they also have no observed transactions in the validation set. We can remove these data points.
pnbd_fixed_1000_validation_tbl %>%
filter(
!(tnx_count == 0 & (tnx_count_qval == 1))
) %>%
filter(tnx_count == 0) %>%
glimpse()
## Rows: 967
## Columns: 6
## $ customer_id    <chr> "C201101_0025", "C201101_0084", "C201101_0091", "C20110…
## $ sim_count      <list> 2000, 2000, 2000, 2000, 2000, 2000, 2000, 2000, 2000, …
## $ tnx_count      <int> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ first_tnx      <dttm> NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA…
## $ last_tnx       <dttm> NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA…
## $ tnx_count_qval <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
Inspecting this, we see there are a number of customers that have no observed transactions but also have a \(q\)-value of less than 1. In many cases this is fine, but we also see customers with values close to 1, suggesting their simulations are almost all zero but include a small number of simulations with at least one transaction.
This needs some further exploration, so we compare these customers to our summary data on the posterior distributions.
pnbd_fixed_1000_validation_tbl %>%
filter(
!(tnx_count == 0 & (tnx_count_qval == 1))
) %>%
filter(tnx_count == 0) %>%
inner_join(pnbd_fixed_palive_summary_tbl, by = "customer_id") %>%
glimpse()
## Rows: 967
## Columns: 13
## $ customer_id     <chr> "C201101_0025", "C201101_0084", "C201101_0091", "C2011…
## $ sim_count       <list> 2000, 2000, 2000, 2000, 2000, 2000, 2000, 2000, 2000,…
## $ tnx_count       <int> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ first_tnx       <dttm> NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, N…
## $ last_tnx        <dttm> NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, N…
## $ tnx_count_qval  <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ p_alive_p10     <dbl> 1.063788e-85, 5.803798e-65, 1.310305e-102, 9.548350e-1…
## $ p_alive_p25     <dbl> 2.433787e-64, 9.276485e-59, 8.515800e-67, 1.568615e-87…
## $ p_alive_p50     <dbl> 8.500835e-47, 1.778625e-52, 7.813795e-42, 1.564965e-70…
## $ p_alive_p75     <dbl> 1.652795e-31, 9.765512e-47, 6.032905e-24, 1.333278e-56…
## $ p_alive_p90     <dbl> 8.281591e-22, 2.221119e-41, 7.266608e-14, 9.218306e-47…
## $ p_alive_range50 <dbl> 1.652795e-31, 9.765512e-47, 6.032905e-24, 1.333278e-56…
## $ p_alive_range80 <dbl> 8.281591e-22, 2.221119e-41, 7.266608e-14, 9.218306e-47…
Looking at this data identifies something unusual: customers whose values of p_alive are essentially zero are still producing non-zero transaction counts in our simulations, so we need to investigate.
In particular, we look at one such customer in this summary dataset, customer C201212_0330.
pnbd_fixed_validsims_tbl %>%
filter(
customer_id == "C201212_0330"
) %>%
mutate(data = map(sim_file, read_rds)) %>%
unnest(data) %>%
select(customer_id, post_lambda, post_mu, p_alive) %>%
arrange(desc(p_alive)) %>%
glimpse()
## Rows: 2,000
## Columns: 4
## $ customer_id <chr> "C201212_0330", "C201212_0330", "C201212_0330", "C201212_0…
## $ post_lambda <dbl> 0.003349470, 0.001005140, 0.009539530, 0.004652750, 0.0037…
## $ post_mu     <dbl> 0.000351128, 0.000610718, 0.000359878, 0.000921058, 0.0012…
## $ p_alive     <dbl> 0.8277770, 0.8003220, 0.5633740, 0.5604560, 0.5085010, 0.4…
We see these values come from draws with very low values of both \(\lambda\) and \(\mu\), so we also want to see the values of the posterior that are more typical of this customer by looking at draws around the median of p_alive.
pnbd_fixed_validsims_tbl %>%
filter(
customer_id == "C201212_0330"
) %>%
mutate(data = map(sim_file, read_rds)) %>%
unnest(data) %>%
select(customer_id, draw_id, post_lambda, post_mu, p_alive) %>%
filter(
abs(1 - (p_alive / median(p_alive))) < 0.2
) %>%
arrange(draw_id) %>%
glimpse()
## Rows: 5
## Columns: 5
## $ customer_id <chr> "C201212_0330", "C201212_0330", "C201212_0330", "C201212_0…
## $ draw_id     <int> 142, 920, 1115, 1364, 1418
## $ post_lambda <dbl> 0.0184381, 0.1399050, 0.1506980, 0.2096500, 0.2356350
## $ post_mu     <dbl> 0.20702000, 0.08775780, 0.07744400, 0.02223590, 0.00283791
## $ p_alive     <dbl> 2.09472e-31, 2.49840e-31, 2.44135e-31, 2.67030e-31, 2.7255…
We also want to look at all the transactions for that customer.
customer_transactions_tbl %>%
filter(customer_id == "C201212_0330")
## # A tibble: 1 × 4
## customer_id tnx_timestamp invoice_id tnx_amount
## <chr> <dttm> <chr> <dbl>
## 1 C201212_0330 2012-12-26 02:29:46 T20121226-0011 49.7
Finally, we want to check the posterior distribution for the alternate prior model and see what this looks like.
pnbd_fixed2_valid_lst$validation_tbl %>%
filter(customer_id == "C201212_0330") %>%
select(customer_id, draw_id, post_lambda, post_mu, p_alive) %>%
arrange(desc(p_alive))
## # A tibble: 2,000 × 5
## customer_id draw_id post_lambda post_mu p_alive
## <chr> <int> <dbl> <dbl> <dbl>
## 1 C201212_0330 193 0.0000772 1.16e-10 1
## 2 C201212_0330 214 0.00000298 4.58e-10 1
## 3 C201212_0330 215 0.0000223 5.06e-10 1
## 4 C201212_0330 309 0.000000669 6.66e-10 1
## 5 C201212_0330 310 0.00000213 9.60e-10 1
## 6 C201212_0330 1065 0.00122 1.08e-11 1
## 7 C201212_0330 1066 0.00119 1.38e-11 1
## 8 C201212_0330 1067 0.0000430 4.81e-10 1
## 9 C201212_0330 1750 0.000300 8.17e-13 1
## 10 C201212_0330 1751 0.000937 3.16e-12 1
## # … with 1,990 more rows
We see this prior gives us nonsensical values, at least for low-transaction-count customers, especially those with no further transactions after the initial one.
That said, these outputs are still useful for comparison, so we generate the validation simulations for the alternate-prior model as well.
pnbd_fixed2_validsims_tbl <- pnbd_fixed2_valid_lst$validation_tbl %>%
group_nest(customer_id, .key = "cust_params") %>%
mutate(
sim_file = glue(
"precompute/pnbd_fixed2/sims_pnbd_fixed2_{customer_id}.rds"
),
chunk_data = future_map2(
sim_file, cust_params,
run_pnbd_simulations_chunk,
.options = furrr_options(
globals = c(
"calculate_event_times", "rgamma_mucv", "gamma_mucv2shaperate",
"generate_pnbd_validation_transactions"
),
packages = c("tidyverse", "fs"),
scheduling = FALSE,
seed = 421
),
.progress = TRUE
)
) %>%
select(-cust_params) %>%
unnest(chunk_data)
pnbd_fixed2_validsims_tbl %>% glimpse()
## Rows: 1,000
## Columns: 3
## $ customer_id <chr> "C201101_0025", "C201101_0084", "C201101_0091", "C201101_0…
## $ sim_file <glue> "precompute/pnbd_fixed2/sims_pnbd_fixed2_C201101_0025.rds…
## $ chunk_data <lgl> FALSE, FALSE, FALSE, FALSE, FALSE, FALSE, FALSE, FALSE, FA…
6.3 Comparing Simulation Data
Before we move on to other aspects of model validation, we also want to check the simulation of transaction counts against the observed counts.
6.3.1 Assessing the pnbd_fixed Model
We first want to look at the total transaction count and the count of customers who transacted, comparing the simulated values against those observed in the actual data.
tnx_data_tbl <- obs_2019_stats_tbl %>%
semi_join(pnbd_fixed_validsims_tbl, by = "customer_id")
obs_customer_count <- tnx_data_tbl %>% nrow()
obs_total_tnx_count <- tnx_data_tbl %>% pull(tnx_count) %>% sum()
pnbd_fixed_tnx_simsumm_tbl <- pnbd_fixed_validsims_tbl %>%
mutate(data = map(sim_file, read_rds)) %>%
unnest(data) %>%
group_by(draw_id) %>%
summarise(
.groups = "drop",
sim_customer_count = length(sim_tnx_count[sim_tnx_count > 0]),
sim_total_tnx_count = sum(sim_tnx_count)
)
ggplot(pnbd_fixed_tnx_simsumm_tbl) +
geom_histogram(aes(x = sim_customer_count), binwidth = 1) +
geom_vline(aes(xintercept = obs_customer_count), colour = "red") +
labs(
x = "Simulated Customers With Transactions",
y = "Frequency",
title = "Histogram of Count of Customers Transacted",
subtitle = "Observed Count in Red"
)
ggplot(pnbd_fixed_tnx_simsumm_tbl) +
geom_histogram(aes(x = sim_total_tnx_count), binwidth = 5) +
geom_vline(aes(xintercept = obs_total_tnx_count), colour = "red") +
labs(
x = "Simulated Transaction Count",
y = "Frequency",
title = "Histogram of Count of Total Transaction Count",
subtitle = "Observed Count in Red"
)
6.3.2 Assessing the pnbd_fixed2 Model
We repeat this assessment for our pnbd_fixed2 model, which had a wider
coefficient of variation in our prior parameters.
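The mean/CV parameterisation used for these priors maps onto the shape/rate parameterisation of the Gamma distribution. A plausible sketch of what a helper like gamma_mucv2shaperate computes (the workbook's own helper may differ in name and return type):

```r
# Convert a (mean, coefficient of variation) pair into the (shape, rate)
# parameterisation of the Gamma distribution. For Gamma(shape, rate):
#   mean = shape / rate,   cv = sd / mean = 1 / sqrt(shape)
# so shape = 1 / cv^2 and rate = shape / mean.
gamma_mucv_to_shaperate <- function(mu, cv) {
  shape <- 1 / cv^2
  rate  <- shape / mu

  c(shape = shape, rate = rate)
}

gamma_mucv_to_shaperate(mu = 0.25, cv = 1.00)  # shape = 1, rate = 4
gamma_mucv_to_shaperate(mu = 0.25, cv = 0.60)
```

This parameterisation is convenient because the mean and CV are directly interpretable: the mean sets the typical transaction or dropout rate, and the CV controls how dispersed the prior is around it.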
tnx_data_tbl <- obs_2019_stats_tbl %>%
semi_join(pnbd_fixed2_validsims_tbl, by = "customer_id")
obs_customer_count <- tnx_data_tbl %>% nrow()
obs_total_tnx_count <- tnx_data_tbl %>% pull(tnx_count) %>% sum()
pnbd_fixed2_tnx_simsumm_tbl <- pnbd_fixed2_validsims_tbl %>%
mutate(data = map(sim_file, read_rds)) %>%
unnest(data) %>%
group_by(draw_id) %>%
summarise(
.groups = "drop",
sim_customer_count = length(sim_tnx_count[sim_tnx_count > 0]),
sim_total_tnx_count = sum(sim_tnx_count)
)
ggplot(pnbd_fixed2_tnx_simsumm_tbl) +
geom_histogram(aes(x = sim_customer_count), binwidth = 1) +
geom_vline(aes(xintercept = obs_customer_count), colour = "red") +
labs(
x = "Simulated Customers With Transactions",
y = "Frequency",
title = "Histogram of Count of Customers Transacted",
subtitle = "Observed Count in Red"
)
ggplot(pnbd_fixed2_tnx_simsumm_tbl) +
geom_histogram(aes(x = sim_total_tnx_count), binwidth = 5) +
geom_vline(aes(xintercept = obs_total_tnx_count), colour = "red") +
labs(
x = "Simulated Transaction Count",
y = "Frequency",
title = "Histogram of Count of Total Transaction Count",
subtitle = "Observed Count in Red"
)
We see that, in general, our models are over-estimating the count of customers transacting in the following year, as well as the total count of transactions made.
This suggests we should try an alternative prior, as our priors on both lifetime and transaction rate may skew too high.
7 Fitting Lower-CV Prior Model
Our existing fits have a number of issues, so it is worth refitting this model with a lower coefficient of variation. This should tighten the lower end of the distribution, though the tradeoff is that the right tail of the distribution is less pronounced. Given that our transaction counts are skewing high, this may not be a problem.
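To see this tradeoff concretely, we can compare the upper tail of two Gamma priors with the same mean but different CVs. This is a quick sanity check, not part of the model code:

```r
# Both priors have mean 0.25 for lambda; only the CV differs.
# Using shape = 1 / cv^2 and rate = shape / mean, the lower-CV prior
# concentrates its mass near the mean and shortens the right tail.
q99_wide  <- qgamma(0.99, shape = 1 / 1.00^2, rate = (1 / 1.00^2) / 0.25)  # cv = 1.00
q99_tight <- qgamma(0.99, shape = 1 / 0.60^2, rate = (1 / 0.60^2) / 0.25)  # cv = 0.60

c(wide = q99_wide, tight = q99_tight)
```

The 99th percentile of the tighter prior is noticeably smaller, which is exactly the reduced right tail discussed above.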
stan_modelname <- "pnbd_fixed3"
stanfit_prefix <- str_c("fit_", stan_modelname)
stan_data_lst <- fit_1000_data_tbl %>%
select(customer_id, x, t_x, T_cal) %>%
compose_data(
lambda_mn = 0.25,
lambda_cv = 0.60,
mu_mn = 0.10,
mu_cv = 0.60,
)
pnbd_fixed3_stanfit <- pnbd_fixed_stanmodel$sample(
data = stan_data_lst,
chains = 4,
iter_warmup = 500,
iter_sampling = 500,
seed = 4201,
save_warmup = TRUE,
output_dir = stan_modeldir,
output_basename = stanfit_prefix,
)
## Running MCMC with 4 chains, at most 8 in parallel...
##
## Chain 1 Iteration: 1 / 1000 [ 0%] (Warmup)
## Chain 2 Iteration: 1 / 1000 [ 0%] (Warmup)
## Chain 3 Iteration: 1 / 1000 [ 0%] (Warmup)
## Chain 4 Iteration: 1 / 1000 [ 0%] (Warmup)
## Chain 1 Iteration: 100 / 1000 [ 10%] (Warmup)
## Chain 2 Iteration: 100 / 1000 [ 10%] (Warmup)
## Chain 3 Iteration: 100 / 1000 [ 10%] (Warmup)
## Chain 4 Iteration: 100 / 1000 [ 10%] (Warmup)
## Chain 1 Iteration: 200 / 1000 [ 20%] (Warmup)
## Chain 2 Iteration: 200 / 1000 [ 20%] (Warmup)
## Chain 3 Iteration: 200 / 1000 [ 20%] (Warmup)
## Chain 4 Iteration: 200 / 1000 [ 20%] (Warmup)
## Chain 1 Iteration: 300 / 1000 [ 30%] (Warmup)
## Chain 2 Iteration: 300 / 1000 [ 30%] (Warmup)
## Chain 3 Iteration: 300 / 1000 [ 30%] (Warmup)
## Chain 4 Iteration: 300 / 1000 [ 30%] (Warmup)
## Chain 1 Iteration: 400 / 1000 [ 40%] (Warmup)
## Chain 2 Iteration: 400 / 1000 [ 40%] (Warmup)
## Chain 3 Iteration: 400 / 1000 [ 40%] (Warmup)
## Chain 4 Iteration: 400 / 1000 [ 40%] (Warmup)
## Chain 1 Iteration: 500 / 1000 [ 50%] (Warmup)
## Chain 1 Iteration: 501 / 1000 [ 50%] (Sampling)
## Chain 2 Iteration: 500 / 1000 [ 50%] (Warmup)
## Chain 2 Iteration: 501 / 1000 [ 50%] (Sampling)
## Chain 3 Iteration: 500 / 1000 [ 50%] (Warmup)
## Chain 3 Iteration: 501 / 1000 [ 50%] (Sampling)
## Chain 4 Iteration: 500 / 1000 [ 50%] (Warmup)
## Chain 4 Iteration: 501 / 1000 [ 50%] (Sampling)
## Chain 3 Iteration: 600 / 1000 [ 60%] (Sampling)
## Chain 4 Iteration: 600 / 1000 [ 60%] (Sampling)
## Chain 1 Iteration: 600 / 1000 [ 60%] (Sampling)
## Chain 2 Iteration: 600 / 1000 [ 60%] (Sampling)
## Chain 4 Iteration: 700 / 1000 [ 70%] (Sampling)
## Chain 3 Iteration: 700 / 1000 [ 70%] (Sampling)
## Chain 1 Iteration: 700 / 1000 [ 70%] (Sampling)
## Chain 4 Iteration: 800 / 1000 [ 80%] (Sampling)
## Chain 3 Iteration: 800 / 1000 [ 80%] (Sampling)
## Chain 2 Iteration: 700 / 1000 [ 70%] (Sampling)
## Chain 4 Iteration: 900 / 1000 [ 90%] (Sampling)
## Chain 3 Iteration: 900 / 1000 [ 90%] (Sampling)
## Chain 1 Iteration: 800 / 1000 [ 80%] (Sampling)
## Chain 4 Iteration: 1000 / 1000 [100%] (Sampling)
## Chain 4 finished in 10.7 seconds.
## Chain 2 Iteration: 800 / 1000 [ 80%] (Sampling)
## Chain 3 Iteration: 1000 / 1000 [100%] (Sampling)
## Chain 3 finished in 11.3 seconds.
## Chain 1 Iteration: 900 / 1000 [ 90%] (Sampling)
## Chain 2 Iteration: 900 / 1000 [ 90%] (Sampling)
## Chain 1 Iteration: 1000 / 1000 [100%] (Sampling)
## Chain 1 finished in 13.7 seconds.
## Chain 2 Iteration: 1000 / 1000 [100%] (Sampling)
## Chain 2 finished in 14.1 seconds.
##
## All 4 chains finished successfully.
## Mean chain execution time: 12.4 seconds.
## Total execution time: 14.4 seconds.
pnbd_fixed3_stanfit$summary()
## # A tibble: 3,001 × 10
## variable mean median sd mad q5 q95 rhat ess_bulk
## <chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
## 1 lp__ -2.70e+4 -2.70e+4 34.1 34.0 -2.71e+4 -2.69e+4 1.00 549.
## 2 lambda[1] 2.12e-1 1.97e-1 0.105 0.0952 7.61e-2 4.08e-1 1.00 2789.
## 3 lambda[2] 3.37e-1 3.34e-1 0.0600 0.0564 2.44e-1 4.41e-1 1.00 3172.
## 4 lambda[3] 1.94e-1 1.65e-1 0.135 0.110 3.82e-2 4.61e-1 1.00 1951.
## 5 lambda[4] 3.42e-1 3.30e-1 0.105 0.103 1.94e-1 5.34e-1 1.00 3082.
## 6 lambda[5] 1.94e-1 1.67e-1 0.128 0.114 4.31e-2 4.45e-1 1.00 1912.
## 7 lambda[6] 1.99e-1 1.71e-1 0.126 0.110 4.44e-2 4.35e-1 1.00 2458.
## 8 lambda[7] 1.45e-1 1.42e-1 0.0336 0.0336 9.48e-2 2.04e-1 0.999 2949.
## 9 lambda[8] 2.04e-1 1.88e-1 0.101 0.0962 7.32e-2 3.87e-1 1.00 2691.
## 10 lambda[9] 2.62e-1 2.45e-1 0.113 0.111 1.09e-1 4.58e-1 1.00 2596.
## # … with 2,991 more rows, and 1 more variable: ess_tail <dbl>
We have some basic HMC-based validity statistics we can check.
pnbd_fixed3_stanfit$cmdstan_diagnose()
## Processing csv files: /home/rstudio/workshop/stan_models/fit_pnbd_fixed3-1.csvWarning: non-fatal error reading adaptation data
## , /home/rstudio/workshop/stan_models/fit_pnbd_fixed3-2.csvWarning: non-fatal error reading adaptation data
## , /home/rstudio/workshop/stan_models/fit_pnbd_fixed3-3.csvWarning: non-fatal error reading adaptation data
## , /home/rstudio/workshop/stan_models/fit_pnbd_fixed3-4.csvWarning: non-fatal error reading adaptation data
##
##
## Checking sampler transitions treedepth.
## Treedepth satisfactory for all transitions.
##
## Checking sampler transitions for divergences.
## No divergent transitions found.
##
## Checking E-BFMI - sampler transitions HMC potential energy.
## E-BFMI satisfactory.
##
## Effective sample size satisfactory.
##
## Split R-hat values satisfactory all parameters.
##
## Processing complete, no problems detected.
7.1 Visual Diagnostics of the Sample Validity
We do not repeat the full set of validation checks here, but look at the traceplots, effective sample sizes, and autocorrelations.
pnbd_fixed3_stanfit$draws() %>%
mcmc_trace(pars = parameter_subset) +
ggtitle("Traceplot of Sample Parameters")
pnbd_fixed3_stanfit %>%
neff_ratio(pars = c("lambda", "mu")) %>%
mcmc_neff() +
ggtitle("Plot of Parameter Effective Sample Sizes for Low-CV Priors")
pnbd_fixed3_stanfit$draws() %>%
mcmc_acf(pars = parameter_subset) +
ggtitle("Autocorrelation Plot of Sample Parameters")
7.2 Comparing the Model Validation
We now construct our posterior dataset.
pnbd_fixed3_validation_tbl <- pnbd_fixed3_stanfit %>%
recover_types(fit_1000_data_tbl) %>%
spread_draws(lambda[customer_id], mu[customer_id], p_alive[customer_id]) %>%
ungroup() %>%
select(customer_id, draw_id = .draw, post_lambda = lambda, post_mu = mu, p_alive)
pnbd_fixed3_validation_tbl %>% glimpse()
## Rows: 2,000,000
## Columns: 5
## $ customer_id <chr> "C201101_0025", "C201101_0025", "C201101_0025", "C201101_0…
## $ draw_id <int> 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17,…
## $ post_lambda <dbl> 0.2255100, 0.2779170, 0.1495770, 0.2960230, 0.3571490, 0.1…
## $ post_mu <dbl> 0.0406934, 0.1907740, 0.0306729, 0.0637177, 0.1122680, 0.0…
## $ p_alive <dbl> 3.47344e-47, 1.43115e-83, 5.73541e-32, 7.33789e-64, 1.8094…
We now want to check that customer from before to see if the posterior has any anomalous values.
pnbd_fixed3_validation_tbl %>%
filter(customer_id == "C201212_0330") %>%
select(customer_id, draw_id, post_lambda, post_mu, p_alive) %>%
arrange(desc(p_alive))
## # A tibble: 2,000 × 5
## customer_id draw_id post_lambda post_mu p_alive
## <chr> <int> <dbl> <dbl> <dbl>
## 1 C201212_0330 1494 0.0124 0.0344 5.72e- 7
## 2 C201212_0330 214 0.0358 0.0169 2.05e- 7
## 3 C201212_0330 372 0.0325 0.0236 5.32e- 8
## 4 C201212_0330 46 0.0354 0.0211 5.31e- 8
## 5 C201212_0330 499 0.0290 0.0268 5.14e- 8
## 6 C201212_0330 634 0.0364 0.0230 2.10e- 8
## 7 C201212_0330 917 0.0381 0.0275 2.78e- 9
## 8 C201212_0330 14 0.0255 0.0413 1.29e- 9
## 9 C201212_0330 1799 0.0350 0.0334 9.81e-10
## 10 C201212_0330 1493 0.0195 0.0479 9.11e-10
## # … with 1,990 more rows
This customer now has a consistently very low posterior value for p_alive.
7.3 Generate Model Validation Simulations
We now want to generate our validation simulations using our existing functions, like we did for previous models.
pnbd_fixed3_validsims_tbl <- pnbd_fixed3_validation_tbl %>%
group_nest(customer_id, .key = "cust_params") %>%
mutate(
sim_file = glue(
"precompute/pnbd_fixed3/sims_pnbd_fixed3_{customer_id}.rds"
),
chunk_data = future_map2(
sim_file, cust_params,
run_pnbd_simulations_chunk,
.options = furrr_options(
globals = c(
"calculate_event_times", "rgamma_mucv", "gamma_mucv2shaperate",
"generate_pnbd_validation_transactions"
),
packages = c("tidyverse", "fs"),
scheduling = FALSE,
seed = 421
),
.progress = TRUE
)
) %>%
select(-cust_params) %>%
unnest(chunk_data)
pnbd_fixed3_validsims_tbl %>% glimpse()
## Rows: 1,000
## Columns: 3
## $ customer_id <chr> "C201101_0025", "C201101_0084", "C201101_0091", "C201101_0…
## $ sim_file <glue> "precompute/pnbd_fixed3/sims_pnbd_fixed3_C201101_0025.rds…
## $ chunk_data <lgl> FALSE, FALSE, FALSE, FALSE, FALSE, FALSE, FALSE, FALSE, FA…
Having constructed the validation posterior we now join this data to our observations and calculate the \(q\)-values as before.
pnbd_fixed3_1000_validation_tbl <- pnbd_fixed3_validsims_tbl %>%
group_by(customer_id) %>%
summarise(
.groups = "drop",
sim_count = list(map_int(sim_file, ~ .x %>% read_rds() %>% nrow()))
) %>%
left_join(obs_2019_stats_tbl, by = "customer_id") %>%
replace_na(list(tnx_count = 0)) %>%
mutate(
tnx_count_qval = map2_dbl(sim_count, tnx_count, ~ ecdf(.x)(.y))
)
pnbd_fixed3_1000_validation_tbl %>% glimpse()
## Rows: 1,000
## Columns: 6
## $ customer_id <chr> "C201101_0025", "C201101_0084", "C201101_0091", "C20110…
## $ sim_count <list> 2000, 2000, 2000, 2000, 2000, 2000, 2000, 2000, 2000, …
## $ tnx_count <int> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ first_tnx <dttm> NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA…
## $ last_tnx <dttm> NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA…
## $ tnx_count_qval <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
Once again, for now we want to remove all the customers with no transactions and where all their validation simulations have zero transactions.
plot_tbl <- pnbd_fixed3_1000_validation_tbl %>%
filter(!(tnx_count == 0 & tnx_count_qval == 1))
ggplot(plot_tbl) +
geom_histogram(aes(x = tnx_count_qval), bins = 50) +
labs(
x = "q-Value",
y = "Frequency",
title = "Histogram of q-Values From Posterior Simulation of Transactions Counts"
)
Here we see we still have a heavy bias towards high \(q\)-values, so we need to see exactly what the issues are.
7.4 Assessing this Model
We now repeat our assessment approach for this new model, investigating our simulations by looking at transaction counts.
tnx_data_tbl <- obs_2019_stats_tbl %>%
semi_join(pnbd_fixed3_validsims_tbl, by = "customer_id")
obs_customer_count <- tnx_data_tbl %>% nrow()
obs_total_tnx_count <- tnx_data_tbl %>% pull(tnx_count) %>% sum()
pnbd_fixed3_tnx_simsumm_tbl <- pnbd_fixed3_validsims_tbl %>%
mutate(data = map(sim_file, read_rds)) %>%
unnest(data) %>%
group_by(draw_id) %>%
summarise(
.groups = "drop",
sim_customer_count = length(sim_tnx_count[sim_tnx_count > 0]),
sim_total_tnx_count = sum(sim_tnx_count)
)
ggplot(pnbd_fixed3_tnx_simsumm_tbl) +
geom_histogram(aes(x = sim_customer_count), binwidth = 1) +
geom_vline(aes(xintercept = obs_customer_count), colour = "red") +
labs(
x = "Simulated Customers With Transactions",
y = "Frequency",
title = "Histogram of Count of Customers Transacted",
subtitle = "Observed Count in Red"
)
ggplot(pnbd_fixed3_tnx_simsumm_tbl) +
geom_histogram(aes(x = sim_total_tnx_count), binwidth = 5) +
geom_vline(aes(xintercept = obs_total_tnx_count), colour = "red") +
labs(
x = "Simulated Transaction Count",
y = "Frequency",
title = "Histogram of Count of Total Transaction Count",
subtitle = "Observed Count in Red"
)
Overall, it looks like our choice of prior parameters has a big effect on our estimates, so we need to take this into account.
An improvement to this model would be to express our uncertainty about these prior parameters within the model itself, which is the next step.
8 Fit First Hierarchical Model
We now extend our P/NBD model by adding hierarchical priors. In particular, we focus on adding a hyperprior for \(\mu\): we have more confidence in our estimates for \(\lambda\), so we want to model our uncertainty in the lifetime of the customer.
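Written out, the hierarchical structure we fit (mirroring the Stan code below, with \(p_1^{mn}, p_2^{mn}, p_1^{cv}, p_2^{cv}\) denoting the hyperprior parameters passed in as data) is:

\[ \begin{eqnarray*} \mu_{mn} &\sim& \text{LogNormal}(p_1^{mn}, \, p_2^{mn}), \\ \mu_{cv} &\sim& \text{LogNormal}(p_1^{cv}, \, p_2^{cv}), \\ \lambda_i &\sim& \text{Gamma}(r, \, \alpha), \\ \mu_i &\sim& \text{Gamma}(s, \, \beta), \end{eqnarray*} \]

where \(s = 1 / \mu_{cv}^2\) and \(\beta = 1 / (\mu_{cv}^2 \, \mu_{mn})\), so the prior on each customer's dropout rate \(\mu_i\) is parameterised by a mean and coefficient of variation that are themselves learned from the data.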
## functions {
## #include util_functions.stan
## }
##
## data {
## int<lower=1> n; // number of customers
##
## vector<lower=0>[n] t_x; // time to most recent purchase
## vector<lower=0>[n] T_cal; // total observation time
## vector<lower=0>[n] x; // number of purchases observed
##
## real<lower=0> lambda_mn; // prior mean for lambda
## real<lower=0> lambda_cv; // prior cv for lambda
##
## real mu_mn_p1; // hyperprior p1 for mu mean
## real<lower=0> mu_mn_p2; // hyperprior p2 for mu mean
##
## real mu_cv_p1; // hyperprior p1 for mu cv
## real<lower=0> mu_cv_p2; // hyperprior p2 for mu cv
## }
##
## transformed data {
## real<lower=0> r = 1 / (lambda_cv * lambda_cv);
## real<lower=0> alpha = 1 / (lambda_cv * lambda_cv * lambda_mn);
## }
##
##
## parameters {
## real<lower=0> mu_mn;
## real<lower=0> mu_cv;
##
## vector<lower=0>[n] lambda; // purchase rate
## vector<lower=0>[n] mu; // lifetime dropout rate
## }
##
##
## transformed parameters {
## real<lower=0> s;
## real<lower=0> beta;
##
## s = 1 / (mu_cv * mu_cv);
## beta = 1 / (mu_cv * mu_cv * mu_mn);
## }
##
## model {
## // model the hyper-prior
## mu_mn ~ lognormal(mu_mn_p1, mu_mn_p2);
## mu_cv ~ lognormal(mu_cv_p1, mu_cv_p2);
##
## // setting priors
## lambda ~ gamma(r, alpha);
## mu ~ gamma(s, beta);
##
## target += calculate_pnbd_loglik(n, lambda, mu, x, t_x, T_cal);
## }
##
## generated quantities {
## vector[n] p_alive; // Probability that they are still "alive"
##
## p_alive = 1 ./ (1 + mu ./ (mu + lambda) .* (exp((lambda + mu) .* (T_cal - t_x)) - 1));
## }
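The generated quantities block computes the standard P/NBD expression for the probability that a customer is still 'alive' at the end of the observation period:

\[ P(\text{alive} \mid \lambda, \mu, t_x, T) = \left[ 1 + \frac{\mu}{\lambda + \mu} \left( e^{(\lambda + \mu)(T - t_x)} - 1 \right) \right]^{-1}, \]

which matches the vectorised Stan expression above.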
We now compile this model using CmdStanR.
pnbd_hiermu_stanmodel <- cmdstan_model(
"stan_code/pnbd_hiermu.stan",
include_paths = stan_codedir,
pedantic = TRUE,
dir = stan_modeldir
)
We then combine this compiled model with our data to produce a fit.
stan_modelname <- "pnbd_hiermu"
stanfit_prefix <- str_c("fit_", stan_modelname)
stan_data_lst <- fit_1000_data_tbl %>%
select(customer_id, x, t_x, T_cal) %>%
compose_data(
lambda_mn = 0.25,
lambda_cv = 1.00,
mu_mn_p1 = log(0.1) - 0.5 * (0.5)^2,
mu_mn_p2 = 0.5,
mu_cv_p1 = log(1) - 0.5 * (0.2)^2,
mu_cv_p2 = 0.2,
)
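The prior means and CVs passed in here are converted to Gamma shape/rate parameters inside the Stan code, and the lognormal hyperprior parameters are chosen so that the distributions have the intended means. A quick sanity check of both parameterisations, using the values from the data list above:

```r
set.seed(42)

# Gamma (mean, CV) -> (shape, rate): shape = 1/cv^2, rate = 1/(cv^2 * mn)
lambda_mn <- 0.25
lambda_cv <- 1.00
lambda_draws <- rgamma(
  1e6, shape = 1 / lambda_cv^2, rate = 1 / (lambda_cv^2 * lambda_mn)
)
mean(lambda_draws)                     # approximately 0.25
sd(lambda_draws) / mean(lambda_draws)  # approximately 1.00

# A LogNormal(p1, p2) has mean exp(p1 + p2^2 / 2), so setting
# p1 = log(m) - 0.5 * p2^2 gives mean m - here 0.1 for mu_mn.
mu_mn_draws <- rlnorm(1e6, meanlog = log(0.1) - 0.5 * 0.5^2, sdlog = 0.5)
mean(mu_mn_draws)                      # approximately 0.1
```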
pnbd_hiermu_stanfit <- pnbd_hiermu_stanmodel$sample(
data = stan_data_lst,
chains = 4,
iter_warmup = 500,
iter_sampling = 500,
seed = 4201,
save_warmup = TRUE,
output_dir = stan_modeldir,
output_basename = stanfit_prefix,
)
## Running MCMC with 4 chains, at most 8 in parallel...
##
## Chain 1 Iteration: 1 / 1000 [ 0%] (Warmup)
## Chain 2 Iteration: 1 / 1000 [ 0%] (Warmup)
## Chain 3 Iteration: 1 / 1000 [ 0%] (Warmup)
## Chain 4 Iteration: 1 / 1000 [ 0%] (Warmup)
## Chain 1 Iteration: 100 / 1000 [ 10%] (Warmup)
## Chain 3 Iteration: 100 / 1000 [ 10%] (Warmup)
## Chain 2 Iteration: 100 / 1000 [ 10%] (Warmup)
## Chain 4 Iteration: 100 / 1000 [ 10%] (Warmup)
## Chain 1 Iteration: 200 / 1000 [ 20%] (Warmup)
## Chain 3 Iteration: 200 / 1000 [ 20%] (Warmup)
## Chain 2 Iteration: 200 / 1000 [ 20%] (Warmup)
## Chain 4 Iteration: 200 / 1000 [ 20%] (Warmup)
## Chain 1 Iteration: 300 / 1000 [ 30%] (Warmup)
## Chain 3 Iteration: 300 / 1000 [ 30%] (Warmup)
## Chain 2 Iteration: 300 / 1000 [ 30%] (Warmup)
## Chain 4 Iteration: 300 / 1000 [ 30%] (Warmup)
## Chain 1 Iteration: 400 / 1000 [ 40%] (Warmup)
## Chain 3 Iteration: 400 / 1000 [ 40%] (Warmup)
## Chain 2 Iteration: 400 / 1000 [ 40%] (Warmup)
## Chain 4 Iteration: 400 / 1000 [ 40%] (Warmup)
## Chain 1 Iteration: 500 / 1000 [ 50%] (Warmup)
## Chain 1 Iteration: 501 / 1000 [ 50%] (Sampling)
## Chain 3 Iteration: 500 / 1000 [ 50%] (Warmup)
## Chain 2 Iteration: 500 / 1000 [ 50%] (Warmup)
## Chain 2 Iteration: 501 / 1000 [ 50%] (Sampling)
## Chain 3 Iteration: 501 / 1000 [ 50%] (Sampling)
## Chain 4 Iteration: 500 / 1000 [ 50%] (Warmup)
## Chain 4 Iteration: 501 / 1000 [ 50%] (Sampling)
## Chain 1 Iteration: 600 / 1000 [ 60%] (Sampling)
## Chain 2 Iteration: 600 / 1000 [ 60%] (Sampling)
## Chain 3 Iteration: 600 / 1000 [ 60%] (Sampling)
## Chain 4 Iteration: 600 / 1000 [ 60%] (Sampling)
## Chain 1 Iteration: 700 / 1000 [ 70%] (Sampling)
## Chain 2 Iteration: 700 / 1000 [ 70%] (Sampling)
## Chain 3 Iteration: 700 / 1000 [ 70%] (Sampling)
## Chain 4 Iteration: 700 / 1000 [ 70%] (Sampling)
## Chain 1 Iteration: 800 / 1000 [ 80%] (Sampling)
## Chain 2 Iteration: 800 / 1000 [ 80%] (Sampling)
## Chain 3 Iteration: 800 / 1000 [ 80%] (Sampling)
## Chain 4 Iteration: 800 / 1000 [ 80%] (Sampling)
## Chain 1 Iteration: 900 / 1000 [ 90%] (Sampling)
## Chain 2 Iteration: 900 / 1000 [ 90%] (Sampling)
## Chain 3 Iteration: 900 / 1000 [ 90%] (Sampling)
## Chain 4 Iteration: 900 / 1000 [ 90%] (Sampling)
## Chain 1 Iteration: 1000 / 1000 [100%] (Sampling)
## Chain 1 finished in 23.6 seconds.
## Chain 2 Iteration: 1000 / 1000 [100%] (Sampling)
## Chain 3 Iteration: 1000 / 1000 [100%] (Sampling)
## Chain 2 finished in 24.0 seconds.
## Chain 3 finished in 23.9 seconds.
## Chain 4 Iteration: 1000 / 1000 [100%] (Sampling)
## Chain 4 finished in 24.2 seconds.
##
## All 4 chains finished successfully.
## Mean chain execution time: 23.9 seconds.
## Total execution time: 25.0 seconds.
pnbd_hiermu_stanfit$summary()
## # A tibble: 3,005 × 10
## variable mean median sd mad q5 q95 rhat ess_bulk
## <chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
## 1 lp__ -1.40e+4 -1.40e+4 5.67e+1 5.46e+1 -1.41e+4 -1.39e+4 1.01 159.
## 2 mu_mn 1.15e-1 1.15e-1 9.06e-3 9.29e-3 1.01e-1 1.31e-1 1.01 262.
## 3 mu_cv 7.85e-1 7.84e-1 4.27e-2 4.34e-2 7.14e-1 8.56e-1 1.01 176.
## 4 lambda[1] 1.95e-1 1.69e-1 1.21e-1 1.08e-1 4.75e-2 4.33e-1 0.999 3978.
## 5 lambda[2] 3.45e-1 3.40e-1 6.20e-2 5.97e-2 2.50e-1 4.50e-1 1.00 3267.
## 6 lambda[3] 1.39e-1 8.85e-2 1.59e-1 9.82e-2 5.13e-3 4.48e-1 1.00 2507.
## 7 lambda[4] 3.65e-1 3.51e-1 1.16e-1 1.09e-1 2.00e-1 5.78e-1 1.00 3962.
## 8 lambda[5] 1.44e-1 8.38e-2 1.76e-1 9.38e-2 5.37e-3 4.75e-1 0.999 3054.
## 9 lambda[6] 1.46e-1 8.37e-2 1.76e-1 1.00e-1 5.00e-3 5.05e-1 1.00 2746.
## 10 lambda[7] 1.39e-1 1.37e-1 3.36e-2 3.41e-2 8.91e-2 1.98e-1 1.00 4115.
## # … with 2,995 more rows, and 1 more variable: ess_tail <dbl>
We have some basic HMC-based validity statistics we can check.
pnbd_hiermu_stanfit$cmdstan_diagnose()
## Processing csv files: /home/rstudio/workshop/stan_models/fit_pnbd_hiermu-1.csvWarning: non-fatal error reading adaptation data
## , /home/rstudio/workshop/stan_models/fit_pnbd_hiermu-2.csvWarning: non-fatal error reading adaptation data
## , /home/rstudio/workshop/stan_models/fit_pnbd_hiermu-3.csvWarning: non-fatal error reading adaptation data
## , /home/rstudio/workshop/stan_models/fit_pnbd_hiermu-4.csvWarning: non-fatal error reading adaptation data
##
##
## Checking sampler transitions treedepth.
## Treedepth satisfactory for all transitions.
##
## Checking sampler transitions for divergences.
## No divergent transitions found.
##
## Checking E-BFMI - sampler transitions HMC potential energy.
## E-BFMI satisfactory.
##
## The following parameters had fewer than 0.001 effective draws per transition:
## p_alive[2836], p_alive[4075]
## Such low values indicate that the effective sample size estimators may be biased high and actual performance may be substantially lower than quoted.
##
## The following parameters had split R-hat greater than 1.05:
## lambda[1101], lambda[1756], lambda[2836], lambda[4004], lambda[4075], lambda[4464], mu[1101], mu[2836], mu[4075], p_alive[5], p_alive[17], p_alive[38], p_alive[76], p_alive[84], p_alive[87], p_alive[91], p_alive[94], p_alive[123], p_alive[129], p_alive[131], p_alive[145], p_alive[150], p_alive[153], p_alive[166], p_alive[232], p_alive[239], p_alive[257], p_alive[274], p_alive[281], p_alive[299], p_alive[335], p_alive[346], p_alive[357], p_alive[409], p_alive[410], p_alive[443], p_alive[460], p_alive[552], p_alive[580], p_alive[651], p_alive[669], p_alive[674], p_alive[704], p_alive[711], p_alive[719], p_alive[756], p_alive[800], p_alive[815], p_alive[819], p_alive[845], p_alive[862], p_alive[893], p_alive[894], p_alive[904], p_alive[916], p_alive[920], p_alive[921], p_alive[923], p_alive[928], p_alive[1020], p_alive[1026], p_alive[1041], p_alive[1060], p_alive[1073], p_alive[1101], p_alive[1118], p_alive[1136], p_alive[1151], p_alive[1152], p_alive[1211], p_alive[1213], p_alive[1220], p_alive[1244], p_alive[1267], p_alive[1274], p_alive[1283], p_alive[1292], p_alive[1354], p_alive[1365], p_alive[1386], p_alive[1394], p_alive[1410], p_alive[1413], p_alive[1424], p_alive[1452], p_alive[1473], p_alive[1495], p_alive[1497], p_alive[1538], p_alive[1546], p_alive[1576], p_alive[1583], p_alive[1585], p_alive[1592], p_alive[1601], p_alive[1603], p_alive[1613], p_alive[1640], p_alive[1676], p_alive[1688], p_alive[1705], p_alive[1708], p_alive[1747], p_alive[1756], p_alive[1782], p_alive[1813], p_alive[1846], p_alive[1863], p_alive[1890], p_alive[1913], p_alive[1914], p_alive[1956], p_alive[1957], p_alive[1966], p_alive[1967], p_alive[1993], p_alive[2001], p_alive[2018], p_alive[2020], p_alive[2067], p_alive[2077], p_alive[2107], p_alive[2114], p_alive[2117], p_alive[2135], p_alive[2169], p_alive[2197], p_alive[2217], p_alive[2218], p_alive[2296], p_alive[2300], p_alive[2318], p_alive[2319], p_alive[2334], p_alive[2339], p_alive[2349], p_alive[2363], p_alive[2371], 
p_alive[2386], p_alive[2405], p_alive[2410], p_alive[2456], p_alive[2498], p_alive[2508], p_alive[2555], p_alive[2558], p_alive[2567], p_alive[2602], p_alive[2603], p_alive[2604], p_alive[2632], p_alive[2651], p_alive[2656], p_alive[2700], p_alive[2787], p_alive[2798], p_alive[2801], p_alive[2821], p_alive[2822], p_alive[2836], p_alive[2853], p_alive[2857], p_alive[2878], p_alive[2881], p_alive[2890], p_alive[2898], p_alive[2931], p_alive[2934], p_alive[2952], p_alive[2953], p_alive[2998], p_alive[3000], p_alive[3016], p_alive[3021], p_alive[3027], p_alive[3032], p_alive[3045], p_alive[3054], p_alive[3080], p_alive[3084], p_alive[3085], p_alive[3095], p_alive[3118], p_alive[3123], p_alive[3124], p_alive[3147], p_alive[3179], p_alive[3206], p_alive[3224], p_alive[3247], p_alive[3249], p_alive[3253], p_alive[3266], p_alive[3289], p_alive[3294], p_alive[3316], p_alive[3336], p_alive[3345], p_alive[3349], p_alive[3354], p_alive[3361], p_alive[3390], p_alive[3393], p_alive[3399], p_alive[3400], p_alive[3436], p_alive[3472], p_alive[3488], p_alive[3493], p_alive[3496], p_alive[3527], p_alive[3533], p_alive[3540], p_alive[3551], p_alive[3565], p_alive[3581], p_alive[3619], p_alive[3633], p_alive[3640], p_alive[3652], p_alive[3660], p_alive[3674], p_alive[3697], p_alive[3706], p_alive[3710], p_alive[3711], p_alive[3731], p_alive[3750], p_alive[3752], p_alive[3779], p_alive[3847], p_alive[3886], p_alive[3906], p_alive[3910], p_alive[3936], p_alive[3960], p_alive[3985], p_alive[3991], p_alive[4004], p_alive[4010], p_alive[4023], p_alive[4034], p_alive[4045], p_alive[4075], p_alive[4078], p_alive[4087], p_alive[4105], p_alive[4131], p_alive[4135], p_alive[4141], p_alive[4146], p_alive[4176], p_alive[4195], p_alive[4219], p_alive[4260], p_alive[4266], p_alive[4278], p_alive[4313], p_alive[4322], p_alive[4341], p_alive[4352], p_alive[4366], p_alive[4415], p_alive[4434], p_alive[4442], p_alive[4457], p_alive[4464], p_alive[4488], p_alive[4508], p_alive[4526], p_alive[4530], 
p_alive[4558], p_alive[4578], p_alive[4601], p_alive[4613], p_alive[4628], p_alive[4633], p_alive[4663], p_alive[4692], p_alive[4696], p_alive[4697], p_alive[4700], p_alive[4709]
## Such high values indicate incomplete mixing and biased estimation.
## You should consider regularizating your model with additional prior information or a more effective parameterization.
##
## Processing complete.
8.1 Visual Diagnostics of the Sample Validity
Now that we have a sample from the posterior distribution, we need to create a few different visualisations of the diagnostics.
pnbd_hiermu_stanfit$draws(inc_warmup = FALSE) %>%
mcmc_trace(pars = parameter_subset) +
expand_limits(y = 0) +
labs(
x = "Iteration",
y = "Value",
title = "Traceplot of Sample of Lambda and Mu Values"
) +
theme(axis.text.x = element_text(size = 10))
A common MCMC diagnostic is \(\hat{R}\) - which is a measure of the ‘similarity’ of the chains.
pnbd_hiermu_stanfit %>%
rhat(pars = c("lambda", "mu", "s", "beta", "mu_mn", "mu_cv")) %>%
mcmc_rhat() +
ggtitle("Plot of Parameter R-hat Values")
Related to this quantity is the concept of effective sample size, \(N_{eff}\), an estimate of the size of the sample from a statistical information point of view.
pnbd_hiermu_stanfit %>%
neff_ratio(pars = c("lambda", "mu", "s", "beta", "mu_mn", "mu_cv")) %>%
mcmc_neff() +
ggtitle("Plot of Parameter Effective Sample Sizes")
Finally, we also want to look at autocorrelation in the chains for each parameter.
pnbd_hiermu_stanfit$draws() %>%
mcmc_acf(pars = parameter_subset) +
ggtitle("Autocorrelation Plot of Sample Values")
As before, we run a comprehensive set of fit diagnostics for this first model; for the sake of brevity, in later models we will show only the traceplots once we are satisfied with the validity of the sample.
8.2 Validate the Hier-Mu Model Fit
We now want to validate this model by using our simulation technique.
We first extract our posterior by customer, and use this as the basis of our simulations.
pnbd_hiermu_validation_tbl <- pnbd_hiermu_stanfit %>%
recover_types(fit_1000_data_tbl) %>%
spread_draws(lambda[customer_id], mu[customer_id], p_alive[customer_id]) %>%
ungroup() %>%
select(customer_id, draw_id = .draw, post_lambda = lambda, post_mu = mu, p_alive)
pnbd_hiermu_validation_tbl %>% glimpse()
## Rows: 2,000,000
## Columns: 5
## $ customer_id <chr> "C201101_0025", "C201101_0025", "C201101_0025", "C201101_0…
## $ draw_id <int> 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17,…
## $ post_lambda <dbl> 0.1314730, 0.2764690, 0.1216930, 0.2560520, 0.3791690, 0.1…
## $ post_mu <dbl> 0.1320430, 0.0795552, 0.1424040, 0.0274170, 0.1295670, 0.0…
## $ p_alive <dbl> 3.17970e-47, 2.65862e-63, 2.32973e-47, 4.71273e-50, 1.7695…
Having constructed our simulation inputs, we now generate our simulations.
pnbd_hiermu_validsims_tbl <- pnbd_hiermu_validation_tbl %>%
group_nest(customer_id, .key = "cust_params") %>%
mutate(
sim_file = glue(
"precompute/pnbd_hiermu/sims_pnbd_hiermu_{customer_id}.rds"
),
chunk_data = future_map2(
sim_file, cust_params,
run_pnbd_simulations_chunk,
.options = furrr_options(
globals = c(
"calculate_event_times", "rgamma_mucv", "gamma_mucv2shaperate",
"generate_pnbd_validation_transactions"
),
packages = c("tidyverse", "fs"),
scheduling = FALSE,
seed = 421
),
.progress = TRUE
)
) %>%
select(-cust_params) %>%
unnest(chunk_data)
pnbd_hiermu_validsims_tbl %>% glimpse()
## Rows: 1,000
## Columns: 3
## $ customer_id <chr> "C201101_0025", "C201101_0084", "C201101_0091", "C201101_0…
## $ sim_file <glue> "precompute/pnbd_hiermu/sims_pnbd_hiermu_C201101_0025.rds…
## $ chunk_data <lgl> FALSE, FALSE, FALSE, FALSE, FALSE, FALSE, FALSE, FALSE, FA…
pnbd_hiermu_1000_validation_tbl <- pnbd_hiermu_validsims_tbl %>%
mutate(data = map(sim_file, read_rds)) %>%
unnest(data) %>%
group_by(customer_id) %>%
summarise(
.groups = "drop",
sim_count = list(map_int(sim_data, nrow))
) %>%
left_join(obs_2019_stats_tbl, by = "customer_id") %>%
replace_na(list(tnx_count = 0)) %>%
mutate(
tnx_count_qval = map2_dbl(sim_count, tnx_count, ~ ecdf(.x)(.y))
)
pnbd_hiermu_1000_validation_tbl %>% glimpse()
## Rows: 1,000
## Columns: 6
## $ customer_id <chr> "C201101_0025", "C201101_0084", "C201101_0091", "C20110…
## $ sim_count <list> <0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,…
## $ tnx_count <int> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ first_tnx <dttm> NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA…
## $ last_tnx <dttm> NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA…
## $ tnx_count_qval <dbl> 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1…
tnx_data_tbl <- obs_2019_stats_tbl %>%
semi_join(pnbd_hiermu_validsims_tbl, by = "customer_id")
obs_customer_count <- tnx_data_tbl %>% nrow()
obs_total_tnx_count <- tnx_data_tbl %>% pull(tnx_count) %>% sum()
pnbd_hiermu_tnx_simsumm_tbl <- pnbd_hiermu_validsims_tbl %>%
mutate(data = map(sim_file, read_rds)) %>%
unnest(data) %>%
group_by(draw_id) %>%
summarise(
.groups = "drop",
sim_customer_count = length(sim_tnx_count[sim_tnx_count > 0]),
sim_total_tnx_count = sum(sim_tnx_count)
)
ggplot(pnbd_hiermu_tnx_simsumm_tbl) +
geom_histogram(aes(x = sim_customer_count), binwidth = 1) +
geom_vline(aes(xintercept = obs_customer_count), colour = "red") +
labs(
x = "Simulated Customers With Transactions",
y = "Frequency",
title = "Histogram of Count of Customers Transacted",
subtitle = "Observed Count in Red"
)
ggplot(pnbd_hiermu_tnx_simsumm_tbl) +
geom_histogram(aes(x = sim_total_tnx_count), binwidth = 5) +
geom_vline(aes(xintercept = obs_total_tnx_count), colour = "red") +
labs(
x = "Simulated Transaction Count",
y = "Frequency",
title = "Histogram of Count of Total Transaction Count",
subtitle = "Observed Count in Red"
)
9 R Environment
options(width = 120L)
sessioninfo::session_info()
## ─ Session info ───────────────────────────────────────────────────────────────────────────────────────────────────────
## setting value
## version R version 4.2.0 (2022-04-22)
## os Ubuntu 20.04.5 LTS
## system x86_64, linux-gnu
## ui RStudio
## language (EN)
## collate en_US.UTF-8
## ctype en_US.UTF-8
## tz Etc/UTC
## date 2022-10-26
## rstudio 2022.02.3+492 Prairie Trillium (server)
## pandoc 2.17.1.1 @ /usr/lib/rstudio-server/bin/quarto/bin/ (via rmarkdown)
##
## ─ Packages ───────────────────────────────────────────────────────────────────────────────────────────────────────────
## package * version date (UTC) lib source
## abind 1.4-5 2016-07-21 [1] RSPM (R 4.2.0)
## arrayhelpers 1.1-0 2020-02-04 [1] RSPM (R 4.2.0)
## assertthat 0.2.1 2019-03-21 [1] RSPM (R 4.2.0)
## backports 1.4.1 2021-12-13 [1] RSPM (R 4.2.0)
## base64enc 0.1-3 2015-07-28 [1] RSPM (R 4.2.0)
## bayesplot * 1.9.0 2022-03-10 [1] RSPM (R 4.2.0)
## bit 4.0.4 2020-08-04 [1] RSPM (R 4.2.0)
## bit64 4.0.5 2020-08-30 [1] RSPM (R 4.2.0)
## bookdown 0.27 2022-06-14 [1] RSPM (R 4.2.0)
## boot 1.3-28 2021-05-03 [2] CRAN (R 4.2.0)
## bridgesampling 1.1-2 2021-04-16 [1] RSPM (R 4.2.0)
## brms * 2.17.0 2022-09-26 [1] Github (paul-buerkner/brms@a43937c)
## Brobdingnag 1.2-7 2022-02-03 [1] RSPM (R 4.2.0)
## broom 0.8.0 2022-04-13 [1] RSPM (R 4.2.0)
## bslib 0.3.1 2021-10-06 [1] RSPM (R 4.2.0)
## cachem 1.0.6 2021-08-19 [1] RSPM (R 4.2.0)
## callr 3.7.0 2021-04-20 [1] RSPM (R 4.2.0)
## cellranger 1.1.0 2016-07-27 [1] RSPM (R 4.2.0)
## checkmate 2.1.0 2022-04-21 [1] RSPM (R 4.2.0)
## cli 3.3.0 2022-04-25 [1] RSPM (R 4.2.0)
## cmdstanr * 0.5.3 2022-09-26 [1] Github (stan-dev/cmdstanr@22b391e)
## coda 0.19-4 2020-09-30 [1] RSPM (R 4.2.0)
## codetools 0.2-18 2020-11-04 [2] CRAN (R 4.2.0)
## colorspace 2.0-3 2022-02-21 [1] RSPM (R 4.2.0)
## colourpicker 1.1.1 2021-10-04 [1] RSPM (R 4.2.0)
## conflicted * 1.1.0 2021-11-26 [1] RSPM (R 4.2.0)
## cowplot * 1.1.1 2020-12-30 [1] RSPM (R 4.2.0)
## crayon 1.5.1 2022-03-26 [1] RSPM (R 4.2.0)
## crosstalk 1.2.0 2021-11-04 [1] RSPM (R 4.2.0)
## curl 4.3.2 2021-06-23 [1] RSPM (R 4.2.0)
## DBI 1.1.3 2022-06-18 [1] RSPM (R 4.2.0)
## dbplyr 2.2.0 2022-06-05 [1] RSPM (R 4.2.0)
## digest 0.6.29 2021-12-01 [1] RSPM (R 4.2.0)
## directlabels * 2021.1.13 2021-01-16 [1] RSPM (R 4.2.0)
## distributional 0.3.0 2022-01-05 [1] RSPM (R 4.2.0)
## dplyr * 1.0.9 2022-04-28 [1] RSPM (R 4.2.0)
## DT 0.23 2022-05-10 [1] RSPM (R 4.2.0)
## dygraphs 1.1.1.6 2018-07-11 [1] RSPM (R 4.2.0)
## ellipsis 0.3.2 2021-04-29 [1] RSPM (R 4.2.0)
## evaluate 0.15 2022-02-18 [1] RSPM (R 4.2.0)
## fansi 1.0.3 2022-03-24 [1] RSPM (R 4.2.0)
## farver 2.1.0 2021-02-28 [1] RSPM (R 4.2.0)
## fastmap 1.1.0 2021-01-25 [1] RSPM (R 4.2.0)
## forcats * 0.5.1 2021-01-27 [1] RSPM (R 4.2.0)
## fs * 1.5.2 2021-12-08 [1] RSPM (R 4.2.0)
## furrr * 0.3.0 2022-05-04 [1] RSPM (R 4.2.0)
## future * 1.26.1 2022-05-27 [1] RSPM (R 4.2.0)
## gamm4 0.2-6 2020-04-03 [1] RSPM (R 4.2.0)
## generics 0.1.2 2022-01-31 [1] RSPM (R 4.2.0)
## ggdist 3.1.1 2022-02-27 [1] RSPM (R 4.2.0)
## ggplot2 * 3.3.6 2022-05-03 [1] RSPM (R 4.2.0)
## ggridges 0.5.3 2021-01-08 [1] RSPM (R 4.2.0)
## globals 0.15.0 2022-05-09 [1] RSPM (R 4.2.0)
## glue * 1.6.2 2022-02-24 [1] RSPM (R 4.2.0)
## gridExtra 2.3 2017-09-09 [1] RSPM (R 4.2.0)
## gtable 0.3.0 2019-03-25 [1] RSPM (R 4.2.0)
## gtools 3.9.2.2 2022-06-13 [1] RSPM (R 4.2.0)
## haven 2.5.0 2022-04-15 [1] RSPM (R 4.2.0)
## highr 0.9 2021-04-16 [1] RSPM (R 4.2.0)
## hms 1.1.1 2021-09-26 [1] RSPM (R 4.2.0)
## htmltools 0.5.2 2021-08-25 [1] RSPM (R 4.2.0)
## htmlwidgets 1.5.4 2021-09-08 [1] RSPM (R 4.2.0)
## httpuv 1.6.5 2022-01-05 [1] RSPM (R 4.2.0)
## httr 1.4.3 2022-05-04 [1] RSPM (R 4.2.0)
## igraph 1.3.2 2022-06-13 [1] RSPM (R 4.2.0)
## inline 0.3.19 2021-05-31 [1] RSPM (R 4.2.0)
## jquerylib 0.1.4 2021-04-26 [1] RSPM (R 4.2.0)
## jsonlite 1.8.0 2022-02-22 [1] RSPM (R 4.2.0)
## knitr 1.39 2022-04-26 [1] RSPM (R 4.2.0)
## labeling 0.4.2 2020-10-20 [1] RSPM (R 4.2.0)
## later 1.3.0 2021-08-18 [1] RSPM (R 4.2.0)
## lattice 0.20-45 2021-09-22 [2] CRAN (R 4.2.0)
## lifecycle 1.0.1 2021-09-24 [1] RSPM (R 4.2.0)
## listenv 0.8.0 2019-12-05 [1] RSPM (R 4.2.0)
## lme4 1.1-29 2022-04-07 [1] RSPM (R 4.2.0)
## loo 2.5.1 2022-03-24 [1] RSPM (R 4.2.0)
## lubridate 1.8.0 2021-10-07 [1] RSPM (R 4.2.0)
## magrittr * 2.0.3 2022-03-30 [1] RSPM (R 4.2.0)
## markdown 1.1 2019-08-07 [1] RSPM (R 4.2.0)
## MASS 7.3-56 2022-03-23 [2] CRAN (R 4.2.0)
## Matrix 1.4-1 2022-03-23 [2] CRAN (R 4.2.0)
## matrixStats 0.62.0 2022-04-19 [1] RSPM (R 4.2.0)
## memoise 2.0.1 2021-11-26 [1] RSPM (R 4.2.0)
## mgcv 1.8-40 2022-03-29 [2] CRAN (R 4.2.0)
## mime 0.12 2021-09-28 [1] RSPM (R 4.2.0)
## miniUI 0.1.1.1 2018-05-18 [1] RSPM (R 4.2.0)
## minqa 1.2.4 2014-10-09 [1] RSPM (R 4.2.0)
## modelr 0.1.8 2020-05-19 [1] RSPM (R 4.2.0)
## munsell 0.5.0 2018-06-12 [1] RSPM (R 4.2.0)
## mvtnorm 1.1-3 2021-10-08 [1] RSPM (R 4.2.0)
## nlme 3.1-157 2022-03-25 [2] CRAN (R 4.2.0)
## nloptr 2.0.3 2022-05-26 [1] RSPM (R 4.2.0)
## parallelly 1.32.0 2022-06-07 [1] RSPM (R 4.2.0)
## pillar 1.7.0 2022-02-01 [1] RSPM (R 4.2.0)
## pkgbuild 1.3.1 2021-12-20 [1] RSPM (R 4.2.0)
## pkgconfig 2.0.3 2019-09-22 [1] RSPM (R 4.2.0)
## plyr 1.8.7 2022-03-24 [1] RSPM (R 4.2.0)
## posterior * 1.2.2 2022-06-09 [1] RSPM (R 4.2.0)
## prettyunits 1.1.1 2020-01-24 [1] RSPM (R 4.2.0)
## processx 3.6.1 2022-06-17 [1] RSPM (R 4.2.0)
## projpred 2.1.2 2022-05-13 [1] RSPM (R 4.2.0)
## promises 1.2.0.1 2021-02-11 [1] RSPM (R 4.2.0)
## ps 1.7.1 2022-06-18 [1] RSPM (R 4.2.0)
## purrr * 0.3.4 2020-04-17 [1] RSPM (R 4.2.0)
## quadprog 1.5-8 2019-11-20 [1] RSPM (R 4.2.0)
## R6 2.5.1 2021-08-19 [1] RSPM (R 4.2.0)
## Rcpp * 1.0.8.3 2022-03-17 [1] RSPM (R 4.2.0)
## RcppParallel 5.1.5 2022-01-05 [1] RSPM (R 4.2.0)
## readr * 2.1.2 2022-01-30 [1] RSPM (R 4.2.0)
## readxl 1.4.0 2022-03-28 [1] RSPM (R 4.2.0)
## reprex 2.0.1 2021-08-05 [1] RSPM (R 4.2.0)
## reshape2 1.4.4 2020-04-09 [1] RSPM (R 4.2.0)
## rlang * 1.0.2 2022-03-04 [1] RSPM (R 4.2.0)
## rmarkdown 2.14 2022-04-25 [1] RSPM (R 4.2.0)
## rmdformats 1.0.4 2022-05-17 [1] RSPM (R 4.2.0)
## rstan 2.26.13 2022-09-26 [1] local
## rstantools 2.2.0 2022-04-08 [1] RSPM (R 4.2.0)
## rstudioapi 0.13 2020-11-12 [1] RSPM (R 4.2.0)
## rvest 1.0.2 2021-10-16 [1] RSPM (R 4.2.0)
## sass 0.4.1 2022-03-23 [1] RSPM (R 4.2.0)
## scales * 1.2.0 2022-04-13 [1] RSPM (R 4.2.0)
## sessioninfo 1.2.2 2021-12-06 [1] RSPM (R 4.2.0)
## shiny 1.7.1 2021-10-02 [1] RSPM (R 4.2.0)
## shinyjs 2.1.0 2021-12-23 [1] RSPM (R 4.2.0)
## shinystan 2.6.0 2022-03-03 [1] RSPM (R 4.2.0)
## shinythemes 1.2.0 2021-01-25 [1] RSPM (R 4.2.0)
## StanHeaders 2.26.13 2022-09-26 [1] local
## stringi 1.7.6 2021-11-29 [1] RSPM (R 4.2.0)
## stringr * 1.4.0 2019-02-10 [1] RSPM (R 4.2.0)
## svUnit 1.0.6 2021-04-19 [1] RSPM (R 4.2.0)
## tensorA 0.36.2 2020-11-19 [1] RSPM (R 4.2.0)
## threejs 0.3.3 2020-01-21 [1] RSPM (R 4.2.0)
## tibble * 3.1.7 2022-05-03 [1] RSPM (R 4.2.0)
## tidybayes * 3.0.2.9000 2022-09-26 [1] Github (mjskay/tidybayes@1efbdef)
## tidyr * 1.2.0 2022-02-01 [1] RSPM (R 4.2.0)
## tidyselect 1.1.2 2022-02-21 [1] RSPM (R 4.2.0)
## tidyverse * 1.3.1 2021-04-15 [1] RSPM (R 4.2.0)
## tzdb 0.3.0 2022-03-28 [1] RSPM (R 4.2.0)
## utf8 1.2.2 2021-07-24 [1] RSPM (R 4.2.0)
## V8 4.2.0 2022-05-14 [1] RSPM (R 4.2.0)
## vctrs 0.4.1 2022-04-13 [1] RSPM (R 4.2.0)
## vroom 1.5.7 2021-11-30 [1] RSPM (R 4.2.0)
## withr 2.5.0 2022-03-03 [1] RSPM (R 4.2.0)
## xfun 0.31 2022-05-10 [1] RSPM (R 4.2.0)
## xml2 1.3.3 2021-11-30 [1] RSPM (R 4.2.0)
## xtable 1.8-4 2019-04-21 [1] RSPM (R 4.2.0)
## xts 0.12.1 2020-09-09 [1] RSPM (R 4.2.0)
## yaml 2.3.5 2022-02-21 [1] RSPM (R 4.2.0)
## zoo 1.8-10 2022-04-15 [1] RSPM (R 4.2.0)
##
## [1] /usr/local/lib/R/site-library
## [2] /usr/local/lib/R/library
##
## ──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
options(width = 80L)